From f9cd0e85b27f228de21e101b1672c72f63c80717 Mon Sep 17 00:00:00 2001 From: zhouyang Date: Sun, 23 Jan 2022 17:55:00 +0800 Subject: [PATCH] add feat of kaldi --- speechx/CMakeLists.txt | 41 +- speechx/speechx/CMakeLists.txt | 14 + speechx/speechx/codelab/README.md | 4 + .../codelab/feat_test/feature-mfcc-test.cc | 686 ++++ speechx/speechx/common/CMakeLists.txt | 0 speechx/speechx/kaldi/CMakeLists.txt | 6 + speechx/speechx/kaldi/base/CMakeLists.txt | 7 + speechx/speechx/kaldi/base/io-funcs-inl.h | 327 ++ speechx/speechx/kaldi/base/io-funcs.cc | 218 ++ speechx/speechx/kaldi/base/io-funcs.h | 245 ++ speechx/speechx/kaldi/base/kaldi-common.h | 41 + speechx/speechx/kaldi/base/kaldi-error.cc | 245 ++ speechx/speechx/kaldi/base/kaldi-error.h | 231 ++ speechx/speechx/kaldi/base/kaldi-math.cc | 162 + speechx/speechx/kaldi/base/kaldi-math.h | 363 ++ speechx/speechx/kaldi/base/kaldi-types.h | 76 + speechx/speechx/kaldi/base/kaldi-utils.cc | 55 + speechx/speechx/kaldi/base/kaldi-utils.h | 155 + speechx/speechx/kaldi/base/timer.cc | 85 + speechx/speechx/kaldi/base/timer.h | 115 + speechx/speechx/kaldi/base/version.h | 4 + speechx/speechx/kaldi/feat/CMakeLists.txt | 19 + .../speechx/kaldi/feat/feature-common-inl.h | 99 + speechx/speechx/kaldi/feat/feature-common.h | 176 + speechx/speechx/kaldi/feat/feature-fbank.cc | 125 + speechx/speechx/kaldi/feat/feature-fbank.h | 149 + .../speechx/kaldi/feat/feature-functions.cc | 362 ++ .../speechx/kaldi/feat/feature-functions.h | 204 ++ speechx/speechx/kaldi/feat/feature-mfcc.cc | 157 + speechx/speechx/kaldi/feat/feature-mfcc.h | 154 + speechx/speechx/kaldi/feat/feature-plp.cc | 191 + speechx/speechx/kaldi/feat/feature-plp.h | 176 + .../speechx/kaldi/feat/feature-spectrogram.cc | 82 + .../speechx/kaldi/feat/feature-spectrogram.h | 117 + speechx/speechx/kaldi/feat/feature-window.cc | 222 ++ speechx/speechx/kaldi/feat/feature-window.h | 223 ++ .../speechx/kaldi/feat/mel-computations.cc | 340 ++ speechx/speechx/kaldi/feat/mel-computations.h 
| 171 + speechx/speechx/kaldi/feat/online-feature.cc | 679 ++++ speechx/speechx/kaldi/feat/online-feature.h | 632 ++++ speechx/speechx/kaldi/feat/pitch-functions.cc | 1667 +++++++++ speechx/speechx/kaldi/feat/pitch-functions.h | 450 +++ speechx/speechx/kaldi/feat/resample.cc | 377 ++ speechx/speechx/kaldi/feat/resample.h | 287 ++ speechx/speechx/kaldi/feat/signal.cc | 129 + speechx/speechx/kaldi/feat/signal.h | 58 + speechx/speechx/kaldi/feat/wave-reader.cc | 387 ++ speechx/speechx/kaldi/feat/wave-reader.h | 248 ++ speechx/speechx/kaldi/matrix/BUILD | 39 + speechx/speechx/kaldi/matrix/CMakeLists.txt | 16 + speechx/speechx/kaldi/matrix/cblas-wrappers.h | 491 +++ .../speechx/kaldi/matrix/compressed-matrix.cc | 876 +++++ .../speechx/kaldi/matrix/compressed-matrix.h | 283 ++ speechx/speechx/kaldi/matrix/jama-eig.h | 924 +++++ speechx/speechx/kaldi/matrix/jama-svd.h | 531 +++ speechx/speechx/kaldi/matrix/kaldi-blas.h | 133 + .../speechx/kaldi/matrix/kaldi-matrix-inl.h | 63 + speechx/speechx/kaldi/matrix/kaldi-matrix.cc | 3103 +++++++++++++++++ speechx/speechx/kaldi/matrix/kaldi-matrix.h | 1122 ++++++ .../speechx/kaldi/matrix/kaldi-vector-inl.h | 58 + speechx/speechx/kaldi/matrix/kaldi-vector.cc | 1355 +++++++ speechx/speechx/kaldi/matrix/kaldi-vector.h | 612 ++++ speechx/speechx/kaldi/matrix/matrix-common.h | 111 + .../kaldi/matrix/matrix-functions-inl.h | 56 + .../speechx/kaldi/matrix/matrix-functions.cc | 773 ++++ .../speechx/kaldi/matrix/matrix-functions.h | 174 + speechx/speechx/kaldi/matrix/matrix-lib.h | 37 + speechx/speechx/kaldi/matrix/optimization.cc | 577 +++ speechx/speechx/kaldi/matrix/optimization.h | 248 ++ speechx/speechx/kaldi/matrix/packed-matrix.cc | 438 +++ speechx/speechx/kaldi/matrix/packed-matrix.h | 197 ++ speechx/speechx/kaldi/matrix/qr.cc | 580 +++ speechx/speechx/kaldi/matrix/sp-matrix-inl.h | 42 + speechx/speechx/kaldi/matrix/sp-matrix.cc | 1216 +++++++ speechx/speechx/kaldi/matrix/sp-matrix.h | 517 +++ 
speechx/speechx/kaldi/matrix/sparse-matrix.cc | 1296 +++++++ speechx/speechx/kaldi/matrix/sparse-matrix.h | 452 +++ speechx/speechx/kaldi/matrix/srfft.cc | 440 +++ speechx/speechx/kaldi/matrix/srfft.h | 141 + speechx/speechx/kaldi/matrix/tp-matrix.cc | 145 + speechx/speechx/kaldi/matrix/tp-matrix.h | 134 + speechx/speechx/kaldi/util/CMakeLists.txt | 12 + speechx/speechx/kaldi/util/basic-filebuf.h | 994 ++++++ speechx/speechx/kaldi/util/common-utils.h | 31 + .../kaldi/util/const-integer-set-inl.h | 91 + .../speechx/kaldi/util/const-integer-set.h | 96 + .../speechx/kaldi/util/edit-distance-inl.h | 200 ++ speechx/speechx/kaldi/util/edit-distance.h | 64 + speechx/speechx/kaldi/util/hash-list-inl.h | 194 ++ speechx/speechx/kaldi/util/hash-list.h | 147 + .../speechx/kaldi/util/kaldi-cygwin-io-inl.h | 129 + speechx/speechx/kaldi/util/kaldi-holder-inl.h | 922 +++++ speechx/speechx/kaldi/util/kaldi-holder.cc | 229 ++ speechx/speechx/kaldi/util/kaldi-holder.h | 282 ++ speechx/speechx/kaldi/util/kaldi-io-inl.h | 46 + speechx/speechx/kaldi/util/kaldi-io.cc | 884 +++++ speechx/speechx/kaldi/util/kaldi-io.h | 280 ++ speechx/speechx/kaldi/util/kaldi-pipebuf.h | 87 + speechx/speechx/kaldi/util/kaldi-semaphore.cc | 57 + speechx/speechx/kaldi/util/kaldi-semaphore.h | 50 + speechx/speechx/kaldi/util/kaldi-table-inl.h | 2672 ++++++++++++++ speechx/speechx/kaldi/util/kaldi-table.cc | 321 ++ speechx/speechx/kaldi/util/kaldi-table.h | 471 +++ speechx/speechx/kaldi/util/kaldi-thread.cc | 33 + speechx/speechx/kaldi/util/kaldi-thread.h | 284 ++ speechx/speechx/kaldi/util/options-itf.h | 49 + speechx/speechx/kaldi/util/parse-options.cc | 668 ++++ speechx/speechx/kaldi/util/parse-options.h | 264 ++ speechx/speechx/kaldi/util/simple-io-funcs.cc | 81 + speechx/speechx/kaldi/util/simple-io-funcs.h | 63 + speechx/speechx/kaldi/util/simple-options.cc | 184 + speechx/speechx/kaldi/util/simple-options.h | 113 + speechx/speechx/kaldi/util/stl-utils.h | 317 ++ speechx/speechx/kaldi/util/table-types.h 
| 192 + speechx/speechx/kaldi/util/text-utils.cc | 591 ++++ speechx/speechx/kaldi/util/text-utils.h | 281 ++ speechx/speechx/third_party/README.md | 4 + 117 files changed, 39505 insertions(+), 19 deletions(-) create mode 100644 speechx/speechx/codelab/README.md create mode 100644 speechx/speechx/codelab/feat_test/feature-mfcc-test.cc create mode 100644 speechx/speechx/common/CMakeLists.txt create mode 100644 speechx/speechx/kaldi/CMakeLists.txt create mode 100644 speechx/speechx/kaldi/base/CMakeLists.txt create mode 100644 speechx/speechx/kaldi/base/io-funcs-inl.h create mode 100644 speechx/speechx/kaldi/base/io-funcs.cc create mode 100644 speechx/speechx/kaldi/base/io-funcs.h create mode 100644 speechx/speechx/kaldi/base/kaldi-common.h create mode 100644 speechx/speechx/kaldi/base/kaldi-error.cc create mode 100644 speechx/speechx/kaldi/base/kaldi-error.h create mode 100644 speechx/speechx/kaldi/base/kaldi-math.cc create mode 100644 speechx/speechx/kaldi/base/kaldi-math.h create mode 100644 speechx/speechx/kaldi/base/kaldi-types.h create mode 100644 speechx/speechx/kaldi/base/kaldi-utils.cc create mode 100644 speechx/speechx/kaldi/base/kaldi-utils.h create mode 100644 speechx/speechx/kaldi/base/timer.cc create mode 100644 speechx/speechx/kaldi/base/timer.h create mode 100644 speechx/speechx/kaldi/base/version.h create mode 100644 speechx/speechx/kaldi/feat/CMakeLists.txt create mode 100644 speechx/speechx/kaldi/feat/feature-common-inl.h create mode 100644 speechx/speechx/kaldi/feat/feature-common.h create mode 100644 speechx/speechx/kaldi/feat/feature-fbank.cc create mode 100644 speechx/speechx/kaldi/feat/feature-fbank.h create mode 100644 speechx/speechx/kaldi/feat/feature-functions.cc create mode 100644 speechx/speechx/kaldi/feat/feature-functions.h create mode 100644 speechx/speechx/kaldi/feat/feature-mfcc.cc create mode 100644 speechx/speechx/kaldi/feat/feature-mfcc.h create mode 100644 speechx/speechx/kaldi/feat/feature-plp.cc create mode 100644 
speechx/speechx/kaldi/feat/feature-plp.h create mode 100644 speechx/speechx/kaldi/feat/feature-spectrogram.cc create mode 100644 speechx/speechx/kaldi/feat/feature-spectrogram.h create mode 100644 speechx/speechx/kaldi/feat/feature-window.cc create mode 100644 speechx/speechx/kaldi/feat/feature-window.h create mode 100644 speechx/speechx/kaldi/feat/mel-computations.cc create mode 100644 speechx/speechx/kaldi/feat/mel-computations.h create mode 100644 speechx/speechx/kaldi/feat/online-feature.cc create mode 100644 speechx/speechx/kaldi/feat/online-feature.h create mode 100644 speechx/speechx/kaldi/feat/pitch-functions.cc create mode 100644 speechx/speechx/kaldi/feat/pitch-functions.h create mode 100644 speechx/speechx/kaldi/feat/resample.cc create mode 100644 speechx/speechx/kaldi/feat/resample.h create mode 100644 speechx/speechx/kaldi/feat/signal.cc create mode 100644 speechx/speechx/kaldi/feat/signal.h create mode 100644 speechx/speechx/kaldi/feat/wave-reader.cc create mode 100644 speechx/speechx/kaldi/feat/wave-reader.h create mode 100644 speechx/speechx/kaldi/matrix/BUILD create mode 100644 speechx/speechx/kaldi/matrix/CMakeLists.txt create mode 100644 speechx/speechx/kaldi/matrix/cblas-wrappers.h create mode 100644 speechx/speechx/kaldi/matrix/compressed-matrix.cc create mode 100644 speechx/speechx/kaldi/matrix/compressed-matrix.h create mode 100644 speechx/speechx/kaldi/matrix/jama-eig.h create mode 100644 speechx/speechx/kaldi/matrix/jama-svd.h create mode 100644 speechx/speechx/kaldi/matrix/kaldi-blas.h create mode 100644 speechx/speechx/kaldi/matrix/kaldi-matrix-inl.h create mode 100644 speechx/speechx/kaldi/matrix/kaldi-matrix.cc create mode 100644 speechx/speechx/kaldi/matrix/kaldi-matrix.h create mode 100644 speechx/speechx/kaldi/matrix/kaldi-vector-inl.h create mode 100644 speechx/speechx/kaldi/matrix/kaldi-vector.cc create mode 100644 speechx/speechx/kaldi/matrix/kaldi-vector.h create mode 100644 speechx/speechx/kaldi/matrix/matrix-common.h create 
mode 100644 speechx/speechx/kaldi/matrix/matrix-functions-inl.h create mode 100644 speechx/speechx/kaldi/matrix/matrix-functions.cc create mode 100644 speechx/speechx/kaldi/matrix/matrix-functions.h create mode 100644 speechx/speechx/kaldi/matrix/matrix-lib.h create mode 100644 speechx/speechx/kaldi/matrix/optimization.cc create mode 100644 speechx/speechx/kaldi/matrix/optimization.h create mode 100644 speechx/speechx/kaldi/matrix/packed-matrix.cc create mode 100644 speechx/speechx/kaldi/matrix/packed-matrix.h create mode 100644 speechx/speechx/kaldi/matrix/qr.cc create mode 100644 speechx/speechx/kaldi/matrix/sp-matrix-inl.h create mode 100644 speechx/speechx/kaldi/matrix/sp-matrix.cc create mode 100644 speechx/speechx/kaldi/matrix/sp-matrix.h create mode 100644 speechx/speechx/kaldi/matrix/sparse-matrix.cc create mode 100644 speechx/speechx/kaldi/matrix/sparse-matrix.h create mode 100644 speechx/speechx/kaldi/matrix/srfft.cc create mode 100644 speechx/speechx/kaldi/matrix/srfft.h create mode 100644 speechx/speechx/kaldi/matrix/tp-matrix.cc create mode 100644 speechx/speechx/kaldi/matrix/tp-matrix.h create mode 100644 speechx/speechx/kaldi/util/CMakeLists.txt create mode 100644 speechx/speechx/kaldi/util/basic-filebuf.h create mode 100644 speechx/speechx/kaldi/util/common-utils.h create mode 100644 speechx/speechx/kaldi/util/const-integer-set-inl.h create mode 100644 speechx/speechx/kaldi/util/const-integer-set.h create mode 100644 speechx/speechx/kaldi/util/edit-distance-inl.h create mode 100644 speechx/speechx/kaldi/util/edit-distance.h create mode 100644 speechx/speechx/kaldi/util/hash-list-inl.h create mode 100644 speechx/speechx/kaldi/util/hash-list.h create mode 100644 speechx/speechx/kaldi/util/kaldi-cygwin-io-inl.h create mode 100644 speechx/speechx/kaldi/util/kaldi-holder-inl.h create mode 100644 speechx/speechx/kaldi/util/kaldi-holder.cc create mode 100644 speechx/speechx/kaldi/util/kaldi-holder.h create mode 100644 
speechx/speechx/kaldi/util/kaldi-io-inl.h create mode 100644 speechx/speechx/kaldi/util/kaldi-io.cc create mode 100644 speechx/speechx/kaldi/util/kaldi-io.h create mode 100644 speechx/speechx/kaldi/util/kaldi-pipebuf.h create mode 100644 speechx/speechx/kaldi/util/kaldi-semaphore.cc create mode 100644 speechx/speechx/kaldi/util/kaldi-semaphore.h create mode 100644 speechx/speechx/kaldi/util/kaldi-table-inl.h create mode 100644 speechx/speechx/kaldi/util/kaldi-table.cc create mode 100644 speechx/speechx/kaldi/util/kaldi-table.h create mode 100644 speechx/speechx/kaldi/util/kaldi-thread.cc create mode 100644 speechx/speechx/kaldi/util/kaldi-thread.h create mode 100644 speechx/speechx/kaldi/util/options-itf.h create mode 100644 speechx/speechx/kaldi/util/parse-options.cc create mode 100644 speechx/speechx/kaldi/util/parse-options.h create mode 100644 speechx/speechx/kaldi/util/simple-io-funcs.cc create mode 100644 speechx/speechx/kaldi/util/simple-io-funcs.h create mode 100644 speechx/speechx/kaldi/util/simple-options.cc create mode 100644 speechx/speechx/kaldi/util/simple-options.h create mode 100644 speechx/speechx/kaldi/util/stl-utils.h create mode 100644 speechx/speechx/kaldi/util/table-types.h create mode 100644 speechx/speechx/kaldi/util/text-utils.cc create mode 100644 speechx/speechx/kaldi/util/text-utils.h create mode 100644 speechx/speechx/third_party/README.md diff --git a/speechx/CMakeLists.txt b/speechx/CMakeLists.txt index 878374ba..12dc594f 100644 --- a/speechx/CMakeLists.txt +++ b/speechx/CMakeLists.txt @@ -49,29 +49,32 @@ FetchContent_Declare( ) FetchContent_MakeAvailable(libsndfile) +add_subdirectory(speechx) + +#openblas +#set(OpenBLAS_INSTALL_PREFIX ${fc_patch}/OpenBLAS) +#set(OpenBLAS_SOURCE_DIR ${fc_patch}/OpenBLAS-src) +#ExternalProject_Add( +# OpenBLAS +# GIT_REPOSITORY https://github.com/xianyi/OpenBLAS +# GIT_TAG v0.3.13 +# GIT_SHALLOW TRUE +# GIT_PROGRESS TRUE +# CONFIGURE_COMMAND "" +# BUILD_IN_SOURCE TRUE +# BUILD_COMMAND make 
USE_LOCKING=1 USE_THREAD=0 +# INSTALL_COMMAND make PREFIX=${OpenBLAS_INSTALL_PREFIX} install +# UPDATE_DISCONNECTED TRUE +#) ############################################################################### # Add local library ############################################################################### # system lib -find_package() +#find_package() # if dir have CmakeLists.txt -add_subdirectory() +#add_subdirectory(speechx) # if dir do not have CmakeLists.txt -add_library(lib_name STATIC file.cc) -target_link_libraries(lib_name item0 item1) -add_dependencies(lib_name depend-target) - - -############################################################################### -# Library installation -############################################################################### -install() - - -############################################################################### -# Build binary file -############################################################################### -add_executable() -target_link_libraries() - +#add_library(lib_name STATIC file.cc) +#target_link_libraries(lib_name item0 item1) +#add_dependencies(lib_name depend-target) \ No newline at end of file diff --git a/speechx/speechx/CMakeLists.txt b/speechx/speechx/CMakeLists.txt index e69de29b..71c7eb7c 100644 --- a/speechx/speechx/CMakeLists.txt +++ b/speechx/speechx/CMakeLists.txt @@ -0,0 +1,14 @@ +cmake_minimum_required(VERSION 3.14 FATAL_ERROR) + +project(speechx LANGUAGES CXX) + +link_directories(${CMAKE_CURRENT_SOURCE_DIR}/third_party/openblas) + +include_directories( +${CMAKE_CURRENT_SOURCE_DIR} +${CMAKE_CURRENT_SOURCE_DIR}/kaldi +) +add_subdirectory(kaldi) + +add_executable(mfcc-test codelab/feat_test/feature-mfcc-test.cc) +target_link_libraries(mfcc-test kaldi-mfcc) diff --git a/speechx/speechx/codelab/README.md b/speechx/speechx/codelab/README.md new file mode 100644 index 00000000..95c95db1 --- /dev/null +++ b/speechx/speechx/codelab/README.md @@ -0,0 +1,4 @@ +# codelab + +This directory is here 
for testing some funcitons temporaril. + diff --git a/speechx/speechx/codelab/feat_test/feature-mfcc-test.cc b/speechx/speechx/codelab/feat_test/feature-mfcc-test.cc new file mode 100644 index 00000000..c4367139 --- /dev/null +++ b/speechx/speechx/codelab/feat_test/feature-mfcc-test.cc @@ -0,0 +1,686 @@ +// feat/feature-mfcc-test.cc + +// Copyright 2009-2011 Karel Vesely; Petr Motlicek + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ + +#include + +#include "feat/feature-mfcc.h" +#include "base/kaldi-math.h" +#include "matrix/kaldi-matrix-inl.h" +#include "feat/wave-reader.h" + +using namespace kaldi; + + + +static void UnitTestReadWave() { + + std::cout << "=== UnitTestReadWave() ===\n"; + + Vector v, v2; + + std::cout << "<<<=== Reading waveform\n"; + + { + std::ifstream is("test_data/test.wav", std::ios_base::binary); + WaveData wave; + wave.Read(is); + const Matrix data(wave.Data()); + KALDI_ASSERT(data.NumRows() == 1); + v.Resize(data.NumCols()); + v.CopyFromVec(data.Row(0)); + } + + std::cout << "<<<=== Reading Vector waveform, prepared by matlab\n"; + std::ifstream input( + "test_data/test_matlab.ascii" + ); + KALDI_ASSERT(input.good()); + v2.Read(input, false); + input.close(); + + std::cout << "<<<=== Comparing freshly read waveform to 'libsndfile' waveform\n"; + KALDI_ASSERT(v.Dim() == v2.Dim()); + for (int32 i = 0; i < v.Dim(); i++) { + KALDI_ASSERT(v(i) == v2(i)); + } + std::cout << "<<<=== Comparing done\n"; + + // std::cout << "== The Waveform Samples == \n"; + // std::cout << v; + + std::cout << "Test passed :)\n\n"; + +} + + + +/** + */ +static void UnitTestSimple() { + std::cout << "=== UnitTestSimple() ===\n"; + + Vector v(100000); + Matrix m; + + // init with noise + for (int32 i = 0; i < v.Dim(); i++) { + v(i) = (abs( i * 433024253 ) % 65535) - (65535 / 2); + } + + std::cout << "<<<=== Just make sure it runs... Nothing is compared\n"; + // the parametrization object + MfccOptions op; + // trying to have same opts as baseline. + op.frame_opts.dither = 0.0; + op.frame_opts.preemph_coeff = 0.0; + op.frame_opts.window_type = "rectangular"; + op.frame_opts.remove_dc_offset = false; + op.frame_opts.round_to_power_of_two = true; + op.mel_opts.low_freq = 0.0; + op.mel_opts.htk_mode = true; + op.htk_compat = true; + + Mfcc mfcc(op); + // use default parameters + + // compute mfccs. 
+ mfcc.Compute(v, 1.0, &m); + + // possibly dump + // std::cout << "== Output features == \n" << m; + std::cout << "Test passed :)\n\n"; +} + + +static void UnitTestHTKCompare1() { + std::cout << "=== UnitTestHTKCompare1() ===\n"; + + std::ifstream is("test_data/test.wav", std::ios_base::binary); + WaveData wave; + wave.Read(is); + KALDI_ASSERT(wave.Data().NumRows() == 1); + SubVector waveform(wave.Data(), 0); + + // read the HTK features + Matrix htk_features; + { + std::ifstream is("test_data/test.wav.fea_htk.1", + std::ios::in | std::ios_base::binary); + bool ans = ReadHtk(is, &htk_features, 0); + KALDI_ASSERT(ans); + } + + // use mfcc with default configuration... + MfccOptions op; + op.frame_opts.dither = 0.0; + op.frame_opts.preemph_coeff = 0.0; + op.frame_opts.window_type = "hamming"; + op.frame_opts.remove_dc_offset = false; + op.frame_opts.round_to_power_of_two = true; + op.mel_opts.low_freq = 0.0; + op.mel_opts.htk_mode = true; + op.htk_compat = true; + op.use_energy = false; // C0 not energy. + + Mfcc mfcc(op); + + // calculate kaldi features + Matrix kaldi_raw_features; + mfcc.Compute(waveform, 1.0, &kaldi_raw_features); + + DeltaFeaturesOptions delta_opts; + Matrix kaldi_features; + ComputeDeltas(delta_opts, + kaldi_raw_features, + &kaldi_features); + + // compare the results + bool passed = true; + int32 i_old = -1; + KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); + KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); + // Ignore ends-- we make slightly different choices than + // HTK about how to treat the deltas at the ends. + for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) { + for (int32 j = 0; j < kaldi_features.NumCols(); j++) { + BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); + if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! 
+ // print the non-matching data only once per-line + if (i_old != i) { + std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n"; + std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n"; + i_old = i; + } + // print indices of non-matching cells + std::cout << "[" << i << ", " << j << "]"; + passed = false; + }}} + if (!passed) KALDI_ERR << "Test failed"; + + // write the htk features for later inspection + HtkHeader header = { + kaldi_features.NumRows(), + 100000, // 10ms + static_cast(sizeof(float)*kaldi_features.NumCols()), + 021406 // MFCC_D_A_0 + }; + { + std::ofstream os("tmp.test.wav.fea_kaldi.1", + std::ios::out|std::ios::binary); + WriteHtk(os, kaldi_features, header); + } + + std::cout << "Test passed :)\n\n"; + + unlink("tmp.test.wav.fea_kaldi.1"); +} + + +static void UnitTestHTKCompare2() { + std::cout << "=== UnitTestHTKCompare2() ===\n"; + + std::ifstream is("test_data/test.wav", std::ios_base::binary); + WaveData wave; + wave.Read(is); + KALDI_ASSERT(wave.Data().NumRows() == 1); + SubVector waveform(wave.Data(), 0); + + // read the HTK features + Matrix htk_features; + { + std::ifstream is("test_data/test.wav.fea_htk.2", + std::ios::in | std::ios_base::binary); + bool ans = ReadHtk(is, &htk_features, 0); + KALDI_ASSERT(ans); + } + + // use mfcc with default configuration... + MfccOptions op; + op.frame_opts.dither = 0.0; + op.frame_opts.preemph_coeff = 0.0; + op.frame_opts.window_type = "hamming"; + op.frame_opts.remove_dc_offset = false; + op.frame_opts.round_to_power_of_two = true; + op.mel_opts.low_freq = 0.0; + op.mel_opts.htk_mode = true; + op.htk_compat = true; + op.use_energy = true; // Use energy. 
+ + Mfcc mfcc(op); + + // calculate kaldi features + Matrix kaldi_raw_features; + mfcc.Compute(waveform, 1.0, &kaldi_raw_features); + + DeltaFeaturesOptions delta_opts; + Matrix kaldi_features; + ComputeDeltas(delta_opts, + kaldi_raw_features, + &kaldi_features); + + // compare the results + bool passed = true; + int32 i_old = -1; + KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); + KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); + // Ignore ends-- we make slightly different choices than + // HTK about how to treat the deltas at the ends. + for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) { + for (int32 j = 0; j < kaldi_features.NumCols(); j++) { + BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); + if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! + // print the non-matching data only once per-line + if (i_old != i) { + std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n"; + std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n"; + i_old = i; + } + // print indices of non-matching cells + std::cout << "[" << i << ", " << j << "]"; + passed = false; + }}} + if (!passed) KALDI_ERR << "Test failed"; + + // write the htk features for later inspection + HtkHeader header = { + kaldi_features.NumRows(), + 100000, // 10ms + static_cast(sizeof(float)*kaldi_features.NumCols()), + 021406 // MFCC_D_A_0 + }; + { + std::ofstream os("tmp.test.wav.fea_kaldi.2", + std::ios::out|std::ios::binary); + WriteHtk(os, kaldi_features, header); + } + + std::cout << "Test passed :)\n\n"; + + unlink("tmp.test.wav.fea_kaldi.2"); +} + + +static void UnitTestHTKCompare3() { + std::cout << "=== UnitTestHTKCompare3() ===\n"; + + std::ifstream is("test_data/test.wav", std::ios_base::binary); + WaveData wave; + wave.Read(is); + KALDI_ASSERT(wave.Data().NumRows() == 1); + SubVector waveform(wave.Data(), 0); + + // read the HTK features + Matrix htk_features; + { + std::ifstream 
is("test_data/test.wav.fea_htk.3", + std::ios::in | std::ios_base::binary); + bool ans = ReadHtk(is, &htk_features, 0); + KALDI_ASSERT(ans); + } + + // use mfcc with default configuration... + MfccOptions op; + op.frame_opts.dither = 0.0; + op.frame_opts.preemph_coeff = 0.0; + op.frame_opts.window_type = "hamming"; + op.frame_opts.remove_dc_offset = false; + op.frame_opts.round_to_power_of_two = true; + op.htk_compat = true; + op.use_energy = true; // Use energy. + op.mel_opts.low_freq = 20.0; + //op.mel_opts.debug_mel = true; + op.mel_opts.htk_mode = true; + + Mfcc mfcc(op); + + // calculate kaldi features + Matrix kaldi_raw_features; + mfcc.Compute(waveform, 1.0, &kaldi_raw_features); + + DeltaFeaturesOptions delta_opts; + Matrix kaldi_features; + ComputeDeltas(delta_opts, + kaldi_raw_features, + &kaldi_features); + + // compare the results + bool passed = true; + int32 i_old = -1; + KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); + KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); + // Ignore ends-- we make slightly different choices than + // HTK about how to treat the deltas at the ends. + for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) { + for (int32 j = 0; j < kaldi_features.NumCols(); j++) { + BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); + if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! 
+ // print the non-matching data only once per-line + if (static_cast(i_old) != i) { + std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n"; + std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n"; + i_old = i; + } + // print indices of non-matching cells + std::cout << "[" << i << ", " << j << "]"; + passed = false; + }}} + if (!passed) KALDI_ERR << "Test failed"; + + // write the htk features for later inspection + HtkHeader header = { + kaldi_features.NumRows(), + 100000, // 10ms + static_cast(sizeof(float)*kaldi_features.NumCols()), + 021406 // MFCC_D_A_0 + }; + { + std::ofstream os("tmp.test.wav.fea_kaldi.3", + std::ios::out|std::ios::binary); + WriteHtk(os, kaldi_features, header); + } + + std::cout << "Test passed :)\n\n"; + + unlink("tmp.test.wav.fea_kaldi.3"); +} + + +static void UnitTestHTKCompare4() { + std::cout << "=== UnitTestHTKCompare4() ===\n"; + + std::ifstream is("test_data/test.wav", std::ios_base::binary); + WaveData wave; + wave.Read(is); + KALDI_ASSERT(wave.Data().NumRows() == 1); + SubVector waveform(wave.Data(), 0); + + // read the HTK features + Matrix htk_features; + { + std::ifstream is("test_data/test.wav.fea_htk.4", + std::ios::in | std::ios_base::binary); + bool ans = ReadHtk(is, &htk_features, 0); + KALDI_ASSERT(ans); + } + + // use mfcc with default configuration... + MfccOptions op; + op.frame_opts.dither = 0.0; + op.frame_opts.window_type = "hamming"; + op.frame_opts.remove_dc_offset = false; + op.frame_opts.round_to_power_of_two = true; + op.mel_opts.low_freq = 0.0; + op.htk_compat = true; + op.use_energy = true; // Use energy. 
+ op.mel_opts.htk_mode = true; + + Mfcc mfcc(op); + + // calculate kaldi features + Matrix kaldi_raw_features; + mfcc.Compute(waveform, 1.0, &kaldi_raw_features); + + DeltaFeaturesOptions delta_opts; + Matrix kaldi_features; + ComputeDeltas(delta_opts, + kaldi_raw_features, + &kaldi_features); + + // compare the results + bool passed = true; + int32 i_old = -1; + KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); + KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); + // Ignore ends-- we make slightly different choices than + // HTK about how to treat the deltas at the ends. + for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) { + for (int32 j = 0; j < kaldi_features.NumCols(); j++) { + BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); + if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! + // print the non-matching data only once per-line + if (static_cast(i_old) != i) { + std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n"; + std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n"; + i_old = i; + } + // print indices of non-matching cells + std::cout << "[" << i << ", " << j << "]"; + passed = false; + }}} + if (!passed) KALDI_ERR << "Test failed"; + + // write the htk features for later inspection + HtkHeader header = { + kaldi_features.NumRows(), + 100000, // 10ms + static_cast(sizeof(float)*kaldi_features.NumCols()), + 021406 // MFCC_D_A_0 + }; + { + std::ofstream os("tmp.test.wav.fea_kaldi.4", + std::ios::out|std::ios::binary); + WriteHtk(os, kaldi_features, header); + } + + std::cout << "Test passed :)\n\n"; + + unlink("tmp.test.wav.fea_kaldi.4"); +} + + +static void UnitTestHTKCompare5() { + std::cout << "=== UnitTestHTKCompare5() ===\n"; + + std::ifstream is("test_data/test.wav", std::ios_base::binary); + WaveData wave; + wave.Read(is); + KALDI_ASSERT(wave.Data().NumRows() == 1); + SubVector waveform(wave.Data(), 0); + + // read the HTK features + 
Matrix htk_features; + { + std::ifstream is("test_data/test.wav.fea_htk.5", + std::ios::in | std::ios_base::binary); + bool ans = ReadHtk(is, &htk_features, 0); + KALDI_ASSERT(ans); + } + + // use mfcc with default configuration... + MfccOptions op; + op.frame_opts.dither = 0.0; + op.frame_opts.window_type = "hamming"; + op.frame_opts.remove_dc_offset = false; + op.frame_opts.round_to_power_of_two = true; + op.htk_compat = true; + op.use_energy = true; // Use energy. + op.mel_opts.low_freq = 0.0; + op.mel_opts.vtln_low = 100.0; + op.mel_opts.vtln_high = 7500.0; + op.mel_opts.htk_mode = true; + + BaseFloat vtln_warp = 1.1; // our approach identical to htk for warp factor >1, + // differs slightly for higher mel bins if warp_factor <0.9 + + Mfcc mfcc(op); + + // calculate kaldi features + Matrix kaldi_raw_features; + mfcc.Compute(waveform, vtln_warp, &kaldi_raw_features); + + DeltaFeaturesOptions delta_opts; + Matrix kaldi_features; + ComputeDeltas(delta_opts, + kaldi_raw_features, + &kaldi_features); + + // compare the results + bool passed = true; + int32 i_old = -1; + KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); + KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); + // Ignore ends-- we make slightly different choices than + // HTK about how to treat the deltas at the ends. + for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) { + for (int32 j = 0; j < kaldi_features.NumCols(); j++) { + BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); + if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! 
+ // print the non-matching data only once per-line + if (static_cast(i_old) != i) { + std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n"; + std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n"; + i_old = i; + } + // print indices of non-matching cells + std::cout << "[" << i << ", " << j << "]"; + passed = false; + }}} + if (!passed) KALDI_ERR << "Test failed"; + + // write the htk features for later inspection + HtkHeader header = { + kaldi_features.NumRows(), + 100000, // 10ms + static_cast(sizeof(float)*kaldi_features.NumCols()), + 021406 // MFCC_D_A_0 + }; + { + std::ofstream os("tmp.test.wav.fea_kaldi.5", + std::ios::out|std::ios::binary); + WriteHtk(os, kaldi_features, header); + } + + std::cout << "Test passed :)\n\n"; + + unlink("tmp.test.wav.fea_kaldi.5"); +} + +static void UnitTestHTKCompare6() { + std::cout << "=== UnitTestHTKCompare6() ===\n"; + + + std::ifstream is("test_data/test.wav", std::ios_base::binary); + WaveData wave; + wave.Read(is); + KALDI_ASSERT(wave.Data().NumRows() == 1); + SubVector waveform(wave.Data(), 0); + + // read the HTK features + Matrix htk_features; + { + std::ifstream is("test_data/test.wav.fea_htk.6", + std::ios::in | std::ios_base::binary); + bool ans = ReadHtk(is, &htk_features, 0); + KALDI_ASSERT(ans); + } + + // use mfcc with default configuration... + MfccOptions op; + op.frame_opts.dither = 0.0; + op.frame_opts.preemph_coeff = 0.97; + op.frame_opts.window_type = "hamming"; + op.frame_opts.remove_dc_offset = false; + op.frame_opts.round_to_power_of_two = true; + op.mel_opts.num_bins = 24; + op.mel_opts.low_freq = 125.0; + op.mel_opts.high_freq = 7800.0; + op.htk_compat = true; + op.use_energy = false; // C0 not energy. 
+ + Mfcc mfcc(op); + + // calculate kaldi features + Matrix kaldi_raw_features; + mfcc.Compute(waveform, 1.0, &kaldi_raw_features); + + DeltaFeaturesOptions delta_opts; + Matrix kaldi_features; + ComputeDeltas(delta_opts, + kaldi_raw_features, + &kaldi_features); + + // compare the results + bool passed = true; + int32 i_old = -1; + KALDI_ASSERT(kaldi_features.NumRows() == htk_features.NumRows()); + KALDI_ASSERT(kaldi_features.NumCols() == htk_features.NumCols()); + // Ignore ends-- we make slightly different choices than + // HTK about how to treat the deltas at the ends. + for (int32 i = 10; i+10 < kaldi_features.NumRows(); i++) { + for (int32 j = 0; j < kaldi_features.NumCols(); j++) { + BaseFloat a = kaldi_features(i, j), b = htk_features(i, j); + if ((std::abs(b - a)) > 1.0) { //<< TOLERANCE TO DIFFERENCES!!!!! + // print the non-matching data only once per-line + if (static_cast(i_old) != i) { + std::cout << "\n\n\n[HTK-row: " << i << "] " << htk_features.Row(i) << "\n"; + std::cout << "[Kaldi-row: " << i << "] " << kaldi_features.Row(i) << "\n\n\n"; + i_old = i; + } + // print indices of non-matching cells + std::cout << "[" << i << ", " << j << "]"; + passed = false; + }}} + if (!passed) KALDI_ERR << "Test failed"; + + // write the htk features for later inspection + HtkHeader header = { + kaldi_features.NumRows(), + 100000, // 10ms + static_cast(sizeof(float)*kaldi_features.NumCols()), + 021406 // MFCC_D_A_0 + }; + { + std::ofstream os("tmp.test.wav.fea_kaldi.6", + std::ios::out|std::ios::binary); + WriteHtk(os, kaldi_features, header); + } + + std::cout << "Test passed :)\n\n"; + + unlink("tmp.test.wav.fea_kaldi.6"); +} + +void UnitTestVtln() { + // Test the function VtlnWarpFreq. 
+ BaseFloat low_freq = 10, high_freq = 7800, + vtln_low_cutoff = 20, vtln_high_cutoff = 7400; + + for (size_t i = 0; i < 100; i++) { + BaseFloat freq = 5000, warp_factor = 0.9 + RandUniform() * 0.2; + AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff, + low_freq, high_freq, warp_factor, + freq), + freq / warp_factor); + + AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff, + low_freq, high_freq, warp_factor, + low_freq), + low_freq); + AssertEqual(MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff, + low_freq, high_freq, warp_factor, + high_freq), + high_freq); + BaseFloat freq2 = low_freq + (high_freq-low_freq) * RandUniform(), + freq3 = freq2 + (high_freq-freq2) * RandUniform(); // freq3>=freq2 + BaseFloat w2 = MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff, + low_freq, high_freq, warp_factor, + freq2); + BaseFloat w3 = MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff, + low_freq, high_freq, warp_factor, + freq3); + KALDI_ASSERT(w3 >= w2); // increasing function. + BaseFloat w3dash = MelBanks::VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff, + low_freq, high_freq, 1.0, + freq3); + AssertEqual(w3dash, freq3); + } +} + +static void UnitTestFeat() { + UnitTestVtln(); + UnitTestReadWave(); + UnitTestSimple(); + UnitTestHTKCompare1(); + UnitTestHTKCompare2(); + // commenting out this one as it doesn't compare right now I normalized + // the way the FFT bins are treated (removed offset of 0.5)... this seems + // to relate to the way frequency zero behaves. 
+ UnitTestHTKCompare3(); + UnitTestHTKCompare4(); + UnitTestHTKCompare5(); + UnitTestHTKCompare6(); + std::cout << "Tests succeeded.\n"; +} + + + +int main() { + try { + for (int i = 0; i < 5; i++) + UnitTestFeat(); + std::cout << "Tests succeeded.\n"; + return 0; + } catch (const std::exception &e) { + std::cerr << e.what(); + return 1; + } +} + + diff --git a/speechx/speechx/common/CMakeLists.txt b/speechx/speechx/common/CMakeLists.txt new file mode 100644 index 00000000..e69de29b diff --git a/speechx/speechx/kaldi/CMakeLists.txt b/speechx/speechx/kaldi/CMakeLists.txt new file mode 100644 index 00000000..414a6fa0 --- /dev/null +++ b/speechx/speechx/kaldi/CMakeLists.txt @@ -0,0 +1,6 @@ +project(kaldi) + +add_subdirectory(base) +add_subdirectory(util) +add_subdirectory(feat) +add_subdirectory(matrix) diff --git a/speechx/speechx/kaldi/base/CMakeLists.txt b/speechx/speechx/kaldi/base/CMakeLists.txt new file mode 100644 index 00000000..f738bf2d --- /dev/null +++ b/speechx/speechx/kaldi/base/CMakeLists.txt @@ -0,0 +1,7 @@ + +add_library(kaldi-base + io-funcs.cc + kaldi-error.cc + kaldi-math.cc + kaldi-utils.cc + timer.cc) \ No newline at end of file diff --git a/speechx/speechx/kaldi/base/io-funcs-inl.h b/speechx/speechx/kaldi/base/io-funcs-inl.h new file mode 100644 index 00000000..b703ef5a --- /dev/null +++ b/speechx/speechx/kaldi/base/io-funcs-inl.h @@ -0,0 +1,327 @@ +// base/io-funcs-inl.h + +// Copyright 2009-2011 Microsoft Corporation; Saarland University; +// Jan Silovsky; Yanmin Qian; +// Johns Hopkins University (Author: Daniel Povey) +// 2016 Xiaohui Zhang + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_BASE_IO_FUNCS_INL_H_ +#define KALDI_BASE_IO_FUNCS_INL_H_ 1 + +// Do not include this file directly. It is included by base/io-funcs.h + +#include +#include + +namespace kaldi { + +// Template that covers integers. +template void WriteBasicType(std::ostream &os, + bool binary, T t) { + // Compile time assertion that this is not called with a wrong type. + KALDI_ASSERT_IS_INTEGER_TYPE(T); + if (binary) { + char len_c = (std::numeric_limits::is_signed ? 1 : -1) + * static_cast(sizeof(t)); + os.put(len_c); + os.write(reinterpret_cast(&t), sizeof(t)); + } else { + if (sizeof(t) == 1) + os << static_cast(t) << " "; + else + os << t << " "; + } + if (os.fail()) { + KALDI_ERR << "Write failure in WriteBasicType."; + } +} + +// Template that covers integers. +template inline void ReadBasicType(std::istream &is, + bool binary, T *t) { + KALDI_PARANOID_ASSERT(t != NULL); + // Compile time assertion that this is not called with a wrong type. + KALDI_ASSERT_IS_INTEGER_TYPE(T); + if (binary) { + int len_c_in = is.get(); + if (len_c_in == -1) + KALDI_ERR << "ReadBasicType: encountered end of stream."; + char len_c = static_cast(len_c_in), len_c_expected + = (std::numeric_limits::is_signed ? 1 : -1) + * static_cast(sizeof(*t)); + if (len_c != len_c_expected) { + KALDI_ERR << "ReadBasicType: did not get expected integer type, " + << static_cast(len_c) + << " vs. " << static_cast(len_c_expected) + << ". 
You can change this code to successfully" + << " read it later, if needed."; + // insert code here to read "wrong" type. Might have a switch statement. + } + is.read(reinterpret_cast(t), sizeof(*t)); + } else { + if (sizeof(*t) == 1) { + int16 i; + is >> i; + *t = i; + } else { + is >> *t; + } + } + if (is.fail()) { + KALDI_ERR << "Read failure in ReadBasicType, file position is " + << is.tellg() << ", next char is " << is.peek(); + } +} + +// Template that covers integers. +template +inline void WriteIntegerPairVector(std::ostream &os, bool binary, + const std::vector > &v) { + // Compile time assertion that this is not called with a wrong type. + KALDI_ASSERT_IS_INTEGER_TYPE(T); + if (binary) { + char sz = sizeof(T); // this is currently just a check. + os.write(&sz, 1); + int32 vecsz = static_cast(v.size()); + KALDI_ASSERT((size_t)vecsz == v.size()); + os.write(reinterpret_cast(&vecsz), sizeof(vecsz)); + if (vecsz != 0) { + os.write(reinterpret_cast(&(v[0])), sizeof(T) * vecsz * 2); + } + } else { + // focus here is on prettiness of text form rather than + // efficiency of reading-in. + // reading-in is dominated by low-level operations anyway: + // for efficiency use binary. + os << "[ "; + typename std::vector >::const_iterator iter = v.begin(), + end = v.end(); + for (; iter != end; ++iter) { + if (sizeof(T) == 1) + os << static_cast(iter->first) << ',' + << static_cast(iter->second) << ' '; + else + os << iter->first << ',' + << iter->second << ' '; + } + os << "]\n"; + } + if (os.fail()) { + KALDI_ERR << "Write failure in WriteIntegerPairVector."; + } +} + +// Template that covers integers. +template +inline void ReadIntegerPairVector(std::istream &is, bool binary, + std::vector > *v) { + KALDI_ASSERT_IS_INTEGER_TYPE(T); + KALDI_ASSERT(v != NULL); + if (binary) { + int sz = is.peek(); + if (sz == sizeof(T)) { + is.get(); + } else { // this is currently just a check. 
+ KALDI_ERR << "ReadIntegerPairVector: expected to see type of size " + << sizeof(T) << ", saw instead " << sz << ", at file position " + << is.tellg(); + } + int32 vecsz; + is.read(reinterpret_cast(&vecsz), sizeof(vecsz)); + if (is.fail() || vecsz < 0) goto bad; + v->resize(vecsz); + if (vecsz > 0) { + is.read(reinterpret_cast(&((*v)[0])), sizeof(T)*vecsz*2); + } + } else { + std::vector > tmp_v; // use temporary so v doesn't use extra memory + // due to resizing. + is >> std::ws; + if (is.peek() != static_cast('[')) { + KALDI_ERR << "ReadIntegerPairVector: expected to see [, saw " + << is.peek() << ", at file position " << is.tellg(); + } + is.get(); // consume the '['. + is >> std::ws; // consume whitespace. + while (is.peek() != static_cast(']')) { + if (sizeof(T) == 1) { // read/write chars as numbers. + int16 next_t1, next_t2; + is >> next_t1; + if (is.fail()) goto bad; + if (is.peek() != static_cast(',')) + KALDI_ERR << "ReadIntegerPairVector: expected to see ',', saw " + << is.peek() << ", at file position " << is.tellg(); + is.get(); // consume the ','. + is >> next_t2 >> std::ws; + if (is.fail()) goto bad; + else + tmp_v.push_back(std::make_pair((T)next_t1, (T)next_t2)); + } else { + T next_t1, next_t2; + is >> next_t1; + if (is.fail()) goto bad; + if (is.peek() != static_cast(',')) + KALDI_ERR << "ReadIntegerPairVector: expected to see ',', saw " + << is.peek() << ", at file position " << is.tellg(); + is.get(); // consume the ','. + is >> next_t2 >> std::ws; + if (is.fail()) goto bad; + else + tmp_v.push_back(std::pair(next_t1, next_t2)); + } + } + is.get(); // get the final ']'. + *v = tmp_v; // could use std::swap to use less temporary memory, but this + // uses less permanent memory. 
+ } + if (!is.fail()) return; + bad: + KALDI_ERR << "ReadIntegerPairVector: read failure at file position " + << is.tellg(); +} + +template inline void WriteIntegerVector(std::ostream &os, bool binary, + const std::vector &v) { + // Compile time assertion that this is not called with a wrong type. + KALDI_ASSERT_IS_INTEGER_TYPE(T); + if (binary) { + char sz = sizeof(T); // this is currently just a check. + os.write(&sz, 1); + int32 vecsz = static_cast(v.size()); + KALDI_ASSERT((size_t)vecsz == v.size()); + os.write(reinterpret_cast(&vecsz), sizeof(vecsz)); + if (vecsz != 0) { + os.write(reinterpret_cast(&(v[0])), sizeof(T)*vecsz); + } + } else { + // focus here is on prettiness of text form rather than + // efficiency of reading-in. + // reading-in is dominated by low-level operations anyway: + // for efficiency use binary. + os << "[ "; + typename std::vector::const_iterator iter = v.begin(), end = v.end(); + for (; iter != end; ++iter) { + if (sizeof(T) == 1) + os << static_cast(*iter) << " "; + else + os << *iter << " "; + } + os << "]\n"; + } + if (os.fail()) { + KALDI_ERR << "Write failure in WriteIntegerVector."; + } +} + + +template inline void ReadIntegerVector(std::istream &is, + bool binary, + std::vector *v) { + KALDI_ASSERT_IS_INTEGER_TYPE(T); + KALDI_ASSERT(v != NULL); + if (binary) { + int sz = is.peek(); + if (sz == sizeof(T)) { + is.get(); + } else { // this is currently just a check. + KALDI_ERR << "ReadIntegerVector: expected to see type of size " + << sizeof(T) << ", saw instead " << sz << ", at file position " + << is.tellg(); + } + int32 vecsz; + is.read(reinterpret_cast(&vecsz), sizeof(vecsz)); + if (is.fail() || vecsz < 0) goto bad; + v->resize(vecsz); + if (vecsz > 0) { + is.read(reinterpret_cast(&((*v)[0])), sizeof(T)*vecsz); + } + } else { + std::vector tmp_v; // use temporary so v doesn't use extra memory + // due to resizing. 
+ is >> std::ws; + if (is.peek() != static_cast('[')) { + KALDI_ERR << "ReadIntegerVector: expected to see [, saw " + << is.peek() << ", at file position " << is.tellg(); + } + is.get(); // consume the '['. + is >> std::ws; // consume whitespace. + while (is.peek() != static_cast(']')) { + if (sizeof(T) == 1) { // read/write chars as numbers. + int16 next_t; + is >> next_t >> std::ws; + if (is.fail()) goto bad; + else + tmp_v.push_back((T)next_t); + } else { + T next_t; + is >> next_t >> std::ws; + if (is.fail()) goto bad; + else + tmp_v.push_back(next_t); + } + } + is.get(); // get the final ']'. + *v = tmp_v; // could use std::swap to use less temporary memory, but this + // uses less permanent memory. + } + if (!is.fail()) return; + bad: + KALDI_ERR << "ReadIntegerVector: read failure at file position " + << is.tellg(); +} + + +// Initialize an opened stream for writing by writing an optional binary +// header and modifying the floating-point precision. +inline void InitKaldiOutputStream(std::ostream &os, bool binary) { + // This does not throw exceptions (does not check for errors). + if (binary) { + os.put('\0'); + os.put('B'); + } + // Note, in non-binary mode we may at some point want to mess with + // the precision a bit. + // 7 is a bit more than the precision of float.. + if (os.precision() < 7) + os.precision(7); +} + +/// Initialize an opened stream for reading by detecting the binary header and +// setting the "binary" value appropriately. +inline bool InitKaldiInputStream(std::istream &is, bool *binary) { + // Sets the 'binary' variable. + // Throws exception in the very unusual situation that stream + // starts with '\0' but not then 'B'. + + if (is.peek() == '\0') { // seems to be binary + is.get(); + if (is.peek() != 'B') { + return false; + } + is.get(); + *binary = true; + return true; + } else { + *binary = false; + return true; + } +} + +} // end namespace kaldi. 
+ +#endif // KALDI_BASE_IO_FUNCS_INL_H_ diff --git a/speechx/speechx/kaldi/base/io-funcs.cc b/speechx/speechx/kaldi/base/io-funcs.cc new file mode 100644 index 00000000..150f7409 --- /dev/null +++ b/speechx/speechx/kaldi/base/io-funcs.cc @@ -0,0 +1,218 @@ +// base/io-funcs.cc + +// Copyright 2009-2011 Microsoft Corporation; Saarland University + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/io-funcs.h" +#include "base/kaldi-math.h" + +namespace kaldi { + +template<> +void WriteBasicType(std::ostream &os, bool binary, bool b) { + os << (b ? "T":"F"); + if (!binary) os << " "; + if (os.fail()) + KALDI_ERR << "Write failure in WriteBasicType"; +} + +template<> +void ReadBasicType(std::istream &is, bool binary, bool *b) { + KALDI_PARANOID_ASSERT(b != NULL); + if (!binary) is >> std::ws; // eat up whitespace. 
+ char c = is.peek(); + if (c == 'T') { + *b = true; + is.get(); + } else if (c == 'F') { + *b = false; + is.get(); + } else { + KALDI_ERR << "Read failure in ReadBasicType, file position is " + << is.tellg() << ", next char is " << CharToString(c); + } +} + +template<> +void WriteBasicType(std::ostream &os, bool binary, float f) { + if (binary) { + char c = sizeof(f); + os.put(c); + os.write(reinterpret_cast(&f), sizeof(f)); + } else { + os << f << " "; + } +} + +template<> +void WriteBasicType(std::ostream &os, bool binary, double f) { + if (binary) { + char c = sizeof(f); + os.put(c); + os.write(reinterpret_cast(&f), sizeof(f)); + } else { + os << f << " "; + } +} + +template<> +void ReadBasicType(std::istream &is, bool binary, float *f) { + KALDI_PARANOID_ASSERT(f != NULL); + if (binary) { + double d; + int c = is.peek(); + if (c == sizeof(*f)) { + is.get(); + is.read(reinterpret_cast(f), sizeof(*f)); + } else if (c == sizeof(d)) { + ReadBasicType(is, binary, &d); + *f = d; + } else { + KALDI_ERR << "ReadBasicType: expected float, saw " << is.peek() + << ", at file position " << is.tellg(); + } + } else { + is >> *f; + } + if (is.fail()) { + KALDI_ERR << "ReadBasicType: failed to read, at file position " + << is.tellg(); + } +} + +template<> +void ReadBasicType(std::istream &is, bool binary, double *d) { + KALDI_PARANOID_ASSERT(d != NULL); + if (binary) { + float f; + int c = is.peek(); + if (c == sizeof(*d)) { + is.get(); + is.read(reinterpret_cast(d), sizeof(*d)); + } else if (c == sizeof(f)) { + ReadBasicType(is, binary, &f); + *d = f; + } else { + KALDI_ERR << "ReadBasicType: expected float, saw " << is.peek() + << ", at file position " << is.tellg(); + } + } else { + is >> *d; + } + if (is.fail()) { + KALDI_ERR << "ReadBasicType: failed to read, at file position " + << is.tellg(); + } +} + +void CheckToken(const char *token) { + if (*token == '\0') + KALDI_ERR << "Token is empty (not a valid token)"; + const char *orig_token = token; + while (*token != 
'\0') { + if (::isspace(*token)) + KALDI_ERR << "Token is not a valid token (contains space): '" + << orig_token << "'"; + token++; + } +} + +void WriteToken(std::ostream &os, bool binary, const char *token) { + // binary mode is ignored; + // we use space as termination character in either case. + KALDI_ASSERT(token != NULL); + CheckToken(token); // make sure it's valid (can be read back) + os << token << " "; + if (os.fail()) { + KALDI_ERR << "Write failure in WriteToken."; + } +} + +int Peek(std::istream &is, bool binary) { + if (!binary) is >> std::ws; // eat up whitespace. + return is.peek(); +} + +void WriteToken(std::ostream &os, bool binary, const std::string & token) { + WriteToken(os, binary, token.c_str()); +} + +void ReadToken(std::istream &is, bool binary, std::string *str) { + KALDI_ASSERT(str != NULL); + if (!binary) is >> std::ws; // consume whitespace. + is >> *str; + if (is.fail()) { + KALDI_ERR << "ReadToken, failed to read token at file position " + << is.tellg(); + } + if (!isspace(is.peek())) { + KALDI_ERR << "ReadToken, expected space after token, saw instead " + << CharToString(static_cast(is.peek())) + << ", at file position " << is.tellg(); + } + is.get(); // consume the space. +} + +int PeekToken(std::istream &is, bool binary) { + if (!binary) is >> std::ws; // consume whitespace. + bool read_bracket; + if (static_cast(is.peek()) == '<') { + read_bracket = true; + is.get(); + } else { + read_bracket = false; + } + int ans = is.peek(); + if (read_bracket) { + if (!is.unget()) { + // Clear the bad bit. This code can be (and is in fact) reached, since the + // C++ standard does not guarantee that a call to unget() must succeed. + is.clear(); + } + } + return ans; +} + + +void ExpectToken(std::istream &is, bool binary, const char *token) { + int pos_at_start = is.tellg(); + KALDI_ASSERT(token != NULL); + CheckToken(token); // make sure it's valid (can be read back) + if (!binary) is >> std::ws; // consume whitespace. 
+ std::string str; + is >> str; + is.get(); // consume the space. + if (is.fail()) { + KALDI_ERR << "Failed to read token [started at file position " + << pos_at_start << "], expected " << token; + } + // The second half of the '&&' expression below is so that if we're expecting + // "", we will accept "Foo>" instead. This is so that the model-reading + // code will tolerate errors in PeekToken where is.unget() failed; search for + // is.clear() in PeekToken() for an explanation. + if (strcmp(str.c_str(), token) != 0 && + !(token[0] == '<' && strcmp(str.c_str(), token + 1) == 0)) { + KALDI_ERR << "Expected token \"" << token << "\", got instead \"" + << str <<"\"."; + } +} + +void ExpectToken(std::istream &is, bool binary, const std::string &token) { + ExpectToken(is, binary, token.c_str()); +} + +} // end namespace kaldi diff --git a/speechx/speechx/kaldi/base/io-funcs.h b/speechx/speechx/kaldi/base/io-funcs.h new file mode 100644 index 00000000..895f661e --- /dev/null +++ b/speechx/speechx/kaldi/base/io-funcs.h @@ -0,0 +1,245 @@ +// base/io-funcs.h + +// Copyright 2009-2011 Microsoft Corporation; Saarland University; +// Jan Silovsky; Yanmin Qian +// 2016 Xiaohui Zhang + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef KALDI_BASE_IO_FUNCS_H_ +#define KALDI_BASE_IO_FUNCS_H_ + +// This header only contains some relatively low-level I/O functions. +// The full Kaldi I/O declarations are in ../util/kaldi-io.h +// and ../util/kaldi-table.h +// They were put in util/ in order to avoid making the Matrix library +// dependent on them. + +#include +#include +#include + +#include "base/kaldi-common.h" +#include "base/io-funcs-inl.h" + +namespace kaldi { + + + +/* + This comment describes the Kaldi approach to I/O. All objects can be written + and read in two modes: binary and text. In addition we want to make the I/O + work if we redefine the typedef "BaseFloat" between floats and doubles. + We also want to have control over whitespace in text mode without affecting + the meaning of the file, for pretty-printing purposes. + + Errors are handled by throwing a KaldiFatalError exception. + + For integer and floating-point types (and boolean values): + + WriteBasicType(std::ostream &, bool binary, const T&); + ReadBasicType(std::istream &, bool binary, T*); + + and we expect these functions to be defined in such a way that they work when + the type T changes between float and double, so you can read float into double + and vice versa]. Note that for efficiency and space-saving reasons, the Vector + and Matrix classes do not use these functions [but they preserve the type + interchangeability in their own way] + + For a class (or struct) C: + class C { + .. + Write(std::ostream &, bool binary, [possibly extra optional args for specific classes]) const; + Read(std::istream &, bool binary, [possibly extra optional args for specific classes]); + .. + } + NOTE: The only actual optional args we used are the "add" arguments in + Vector/Matrix classes, which specify whether we should sum the data already + in the class with the data being read. 
+ + For types which are typedef's involving stl classes, I/O is as follows: + typedef std::vector > MyTypedefName; + + The user should define something like: + + WriteMyTypedefName(std::ostream &, bool binary, const MyTypedefName &t); + ReadMyTypedefName(std::ostream &, bool binary, MyTypedefName *t); + + The user would have to write these functions. + + For a type std::vector: + + void WriteIntegerVector(std::ostream &os, bool binary, const std::vector &v); + void ReadIntegerVector(std::istream &is, bool binary, std::vector *v); + + For other types, e.g. vectors of pairs, the user should create a routine of the + type WriteMyTypedefName. This is to avoid introducing confusing templated functions; + we could easily create templated functions to handle most of these cases but they + would have to share the same name. + + It also often happens that the user needs to write/read special tokens as part + of a file. These might be class headers, or separators/identifiers in the class. + We provide special functions for manipulating these. These special tokens must + be nonempty and must not contain any whitespace. + + void WriteToken(std::ostream &os, bool binary, const char*); + void WriteToken(std::ostream &os, bool binary, const std::string & token); + int Peek(std::istream &is, bool binary); + void ReadToken(std::istream &is, bool binary, std::string *str); + void PeekToken(std::istream &is, bool binary, std::string *str); + + WriteToken writes the token and one space (whether in binary or text mode). + + Peek returns the first character of the next token, by consuming whitespace + (in text mode) and then returning the peek() character. It returns -1 at EOF; + it doesn't throw. It's useful if a class can have various forms based on + typedefs and virtual classes, and wants to know which version to read. + + ReadToken allows the caller to obtain the next token. PeekToken works just + like ReadToken, but seeks back to the beginning of the token. 
A subsequent + call to ReadToken will read the same token again. This is useful when + different object types are written to the same file; using PeekToken one can + decide which of the objects to read. + + There is currently no special functionality for writing/reading strings (where the strings + contain data rather than "special tokens" that are whitespace-free and nonempty). This is + because Kaldi is structured in such a way that strings don't appear, except as OpenFst symbol + table entries (and these have their own format). + + + NOTE: you should not call ReadIntegerType and WriteIntegerType with types, + such as int and size_t, that are machine-independent -- at least not + if you want your file formats to port between machines. Use int32 and + int64 where necessary. There is no way to detect this using compile-time + assertions because C++ only keeps track of the internal representation of + the type. +*/ + +/// \addtogroup io_funcs_basic +/// @{ + + +/// WriteBasicType is the name of the write function for bool, integer types, +/// and floating-point types. They all throw on error. +template void WriteBasicType(std::ostream &os, bool binary, T t); + +/// ReadBasicType is the name of the read function for bool, integer types, +/// and floating-point types. They all throw on error. +template void ReadBasicType(std::istream &is, bool binary, T *t); + + +// Declare specialization for bool. +template<> +void WriteBasicType(std::ostream &os, bool binary, bool b); + +template <> +void ReadBasicType(std::istream &is, bool binary, bool *b); + +// Declare specializations for float and double. 
+template<> +void WriteBasicType(std::ostream &os, bool binary, float f); + +template<> +void WriteBasicType(std::ostream &os, bool binary, double f); + +template<> +void ReadBasicType(std::istream &is, bool binary, float *f); + +template<> +void ReadBasicType(std::istream &is, bool binary, double *f); + +// Define ReadBasicType that accepts an "add" parameter to add to +// the destination. Caution: if used in Read functions, be careful +// to initialize the parameters concerned to zero in the default +// constructor. +template +inline void ReadBasicType(std::istream &is, bool binary, T *t, bool add) { + if (!add) { + ReadBasicType(is, binary, t); + } else { + T tmp = T(0); + ReadBasicType(is, binary, &tmp); + *t += tmp; + } +} + +/// Function for writing STL vectors of integer types. +template inline void WriteIntegerVector(std::ostream &os, bool binary, + const std::vector &v); + +/// Function for reading STL vector of integer types. +template inline void ReadIntegerVector(std::istream &is, bool binary, + std::vector *v); + +/// Function for writing STL vectors of pairs of integer types. +template +inline void WriteIntegerPairVector(std::ostream &os, bool binary, + const std::vector > &v); + +/// Function for reading STL vector of pairs of integer types. +template +inline void ReadIntegerPairVector(std::istream &is, bool binary, + std::vector > *v); + +/// The WriteToken functions are for writing nonempty sequences of non-space +/// characters. They are not for general strings. +void WriteToken(std::ostream &os, bool binary, const char *token); +void WriteToken(std::ostream &os, bool binary, const std::string & token); + +/// Peek consumes whitespace (if binary == false) and then returns the peek() +/// value of the stream. +int Peek(std::istream &is, bool binary); + +/// ReadToken gets the next token and puts it in str (exception on failure). 
If +/// PeekToken() had been previously called, it is possible that the stream had +/// failed to unget the starting '<' character. In this case ReadToken() returns +/// the token string without the leading '<'. You must be prepared to handle +/// this case. ExpectToken() handles this internally, and is not affected. +void ReadToken(std::istream &is, bool binary, std::string *token); + +/// PeekToken will return the first character of the next token, or -1 if end of +/// file. It's the same as Peek(), except if the first character is '<' it will +/// skip over it and will return the next character. It will attempt to unget +/// the '<' so the stream is where it was before you did PeekToken(), however, +/// this is not guaranteed (see ReadToken()). +int PeekToken(std::istream &is, bool binary); + +/// ExpectToken tries to read in the given token, and throws an exception +/// on failure. +void ExpectToken(std::istream &is, bool binary, const char *token); +void ExpectToken(std::istream &is, bool binary, const std::string & token); + +/// ExpectPretty attempts to read the text in "token", but only in non-binary +/// mode. Throws exception on failure. It expects an exact match except that +/// arbitrary whitespace matches arbitrary whitespace. +void ExpectPretty(std::istream &is, bool binary, const char *token); +void ExpectPretty(std::istream &is, bool binary, const std::string & token); + +/// @} end "addtogroup io_funcs_basic" + + +/// InitKaldiOutputStream initializes an opened stream for writing by writing an +/// optional binary header and modifying the floating-point precision; it will +/// typically not be called by users directly. +inline void InitKaldiOutputStream(std::ostream &os, bool binary); + +/// InitKaldiInputStream initializes an opened stream for reading by detecting +/// the binary header and setting the "binary" value appropriately; +/// It will typically not be called by users directly. 
+inline bool InitKaldiInputStream(std::istream &is, bool *binary); + +} // end namespace kaldi. +#endif // KALDI_BASE_IO_FUNCS_H_ diff --git a/speechx/speechx/kaldi/base/kaldi-common.h b/speechx/speechx/kaldi/base/kaldi-common.h new file mode 100644 index 00000000..264565d1 --- /dev/null +++ b/speechx/speechx/kaldi/base/kaldi-common.h @@ -0,0 +1,41 @@ +// base/kaldi-common.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef KALDI_BASE_KALDI_COMMON_H_ +#define KALDI_BASE_KALDI_COMMON_H_ 1 + +#include +#include +#include // C string stuff like strcpy +#include +#include +#include +#include +#include +#include +#include + +#include "base/kaldi-utils.h" +#include "base/kaldi-error.h" +#include "base/kaldi-types.h" +#include "base/io-funcs.h" +#include "base/kaldi-math.h" +#include "base/timer.h" + +#endif // KALDI_BASE_KALDI_COMMON_H_ diff --git a/speechx/speechx/kaldi/base/kaldi-error.cc b/speechx/speechx/kaldi/base/kaldi-error.cc new file mode 100644 index 00000000..2dbc7318 --- /dev/null +++ b/speechx/speechx/kaldi/base/kaldi-error.cc @@ -0,0 +1,245 @@ +// base/kaldi-error.cc + +// Copyright 2019 LAIX (Yi Sun) +// Copyright 2019 SmartAction LLC (kkm) +// Copyright 2016 Brno University of Technology (author: Karel Vesely) +// Copyright 2009-2011 Microsoft Corporation; Lukas Burget; Ondrej Glembek + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifdef HAVE_EXECINFO_H +#include // To get stack trace in error messages. +// If this #include fails there is an error in the Makefile, it does not +// support your platform well. Make sure HAVE_EXECINFO_H is undefined, +// and the code will compile. +#ifdef HAVE_CXXABI_H +#include // For name demangling. 
+// Useful to decode the stack trace, but only used if we have execinfo.h +#endif // HAVE_CXXABI_H +#endif // HAVE_EXECINFO_H + +#include "base/kaldi-common.h" +#include "base/kaldi-error.h" +#include "base/version.h" + +namespace kaldi { + +/***** GLOBAL VARIABLES FOR LOGGING *****/ + +int32 g_kaldi_verbose_level = 0; +static std::string program_name; +static LogHandler log_handler = NULL; + +void SetProgramName(const char *basename) { + // Using the 'static std::string' for the program name is mostly harmless, + // because (a) Kaldi logging is undefined before main(), and (b) no stdc++ + // string implementation has been found in the wild that would not be just + // an empty string when zero-initialized but not yet constructed. + program_name = basename; +} + +/***** HELPER FUNCTIONS *****/ + +// Trim filename to at most 1 trailing directory long. Given a filename like +// "/a/b/c/d/e/f.cc", return "e/f.cc". Support both '/' and '\' as the path +// separator. +static const char *GetShortFileName(const char *path) { + if (path == nullptr) + return ""; + + const char *prev = path, *last = path; + while ((path = std::strpbrk(path, "\\/")) != nullptr) { + ++path; + prev = last; + last = path; + } + return prev; +} + +/***** STACK TRACE *****/ + +namespace internal { +bool LocateSymbolRange(const std::string &trace_name, size_t *begin, + size_t *end) { + // Find the first '_' with leading ' ' or '('. + *begin = std::string::npos; + for (size_t i = 1; i < trace_name.size(); i++) { + if (trace_name[i] != '_') { + continue; + } + if (trace_name[i - 1] == ' ' || trace_name[i - 1] == '(') { + *begin = i; + break; + } + } + if (*begin == std::string::npos) { + return false; + } + *end = trace_name.find_first_of(" +", *begin); + return *end != std::string::npos; +} +} // namespace internal + +#ifdef HAVE_EXECINFO_H +static std::string Demangle(std::string trace_name) { +#ifndef HAVE_CXXABI_H + return trace_name; +#else // HAVE_CXXABI_H + // Try demangle the symbol. 
We are trying to support the following formats + // produced by different platforms: + // + // Linux: + // ./kaldi-error-test(_ZN5kaldi13UnitTestErrorEv+0xb) [0x804965d] + // + // Mac: + // 0 server 0x000000010f67614d _ZNK5kaldi13MessageLogger10LogMessageEv + 813 + // + // We want to extract the name e.g., '_ZN5kaldi13UnitTestErrorEv' and + // demangle it info a readable name like kaldi::UnitTextError. + size_t begin, end; + if (!internal::LocateSymbolRange(trace_name, &begin, &end)) { + return trace_name; + } + std::string symbol = trace_name.substr(begin, end - begin); + int status; + char *demangled_name = abi::__cxa_demangle(symbol.c_str(), 0, 0, &status); + if (status == 0 && demangled_name != nullptr) { + symbol = demangled_name; + free(demangled_name); + } + return trace_name.substr(0, begin) + symbol + + trace_name.substr(end, std::string::npos); +#endif // HAVE_CXXABI_H +} +#endif // HAVE_EXECINFO_H + +static std::string KaldiGetStackTrace() { + std::string ans; +#ifdef HAVE_EXECINFO_H + const size_t KALDI_MAX_TRACE_SIZE = 50; + const size_t KALDI_MAX_TRACE_PRINT = 50; // Must be even. + // Buffer for the trace. + void *trace[KALDI_MAX_TRACE_SIZE]; + // Get the trace. + size_t size = backtrace(trace, KALDI_MAX_TRACE_SIZE); + // Get the trace symbols. + char **trace_symbol = backtrace_symbols(trace, size); + if (trace_symbol == NULL) + return ans; + + // Compose a human-readable backtrace string. + ans += "[ Stack-Trace: ]\n"; + if (size <= KALDI_MAX_TRACE_PRINT) { + for (size_t i = 0; i < size; i++) { + ans += Demangle(trace_symbol[i]) + "\n"; + } + } else { // Print out first+last (e.g.) 5. + for (size_t i = 0; i < KALDI_MAX_TRACE_PRINT / 2; i++) { + ans += Demangle(trace_symbol[i]) + "\n"; + } + ans += ".\n.\n.\n"; + for (size_t i = size - KALDI_MAX_TRACE_PRINT / 2; i < size; i++) { + ans += Demangle(trace_symbol[i]) + "\n"; + } + if (size == KALDI_MAX_TRACE_SIZE) + ans += ".\n.\n.\n"; // Stack was too long, probably a bug. 
+ } + + // We must free the array of pointers allocated by backtrace_symbols(), + // but not the strings themselves. + free(trace_symbol); +#endif // HAVE_EXECINFO_H + return ans; +} + +/***** KALDI LOGGING *****/ + +MessageLogger::MessageLogger(LogMessageEnvelope::Severity severity, + const char *func, const char *file, int32 line) { + // Obviously, we assume the strings survive the destruction of this object. + envelope_.severity = severity; + envelope_.func = func; + envelope_.file = GetShortFileName(file); // Points inside 'file'. + envelope_.line = line; +} + +void MessageLogger::LogMessage() const { + // Send to the logging handler if provided. + if (log_handler != NULL) { + log_handler(envelope_, GetMessage().c_str()); + return; + } + + // Otherwise, use the default Kaldi logging. + // Build the log-message header. + std::stringstream full_message; + if (envelope_.severity > LogMessageEnvelope::kInfo) { + full_message << "VLOG[" << envelope_.severity << "] ("; + } else { + switch (envelope_.severity) { + case LogMessageEnvelope::kInfo: + full_message << "LOG ("; + break; + case LogMessageEnvelope::kWarning: + full_message << "WARNING ("; + break; + case LogMessageEnvelope::kAssertFailed: + full_message << "ASSERTION_FAILED ("; + break; + case LogMessageEnvelope::kError: + default: // If not the ERROR, it still an error! + full_message << "ERROR ("; + break; + } + } + // Add other info from the envelope and the message text. + full_message << program_name.c_str() << "[" KALDI_VERSION "]" << ':' + << envelope_.func << "():" << envelope_.file << ':' + << envelope_.line << ") " << GetMessage().c_str(); + + // Add stack trace for errors and assertion failures, if available. + if (envelope_.severity < LogMessageEnvelope::kWarning) { + const std::string &stack_trace = KaldiGetStackTrace(); + if (!stack_trace.empty()) { + full_message << "\n\n" << stack_trace; + } + } + + // Print the complete message to stderr. 
+ full_message << "\n"; + std::cerr << full_message.str(); +} + +/***** KALDI ASSERTS *****/ + +void KaldiAssertFailure_(const char *func, const char *file, int32 line, + const char *cond_str) { + MessageLogger::Log() = + MessageLogger(LogMessageEnvelope::kAssertFailed, func, file, line) + << "Assertion failed: (" << cond_str << ")"; + fflush(NULL); // Flush all pending buffers, abort() may not flush stderr. + std::abort(); +} + +/***** THIRD-PARTY LOG-HANDLER *****/ + +LogHandler SetLogHandler(LogHandler handler) { + LogHandler old_handler = log_handler; + log_handler = handler; + return old_handler; +} + +} // namespace kaldi diff --git a/speechx/speechx/kaldi/base/kaldi-error.h b/speechx/speechx/kaldi/base/kaldi-error.h new file mode 100644 index 00000000..a9904a75 --- /dev/null +++ b/speechx/speechx/kaldi/base/kaldi-error.h @@ -0,0 +1,231 @@ +// base/kaldi-error.h + +// Copyright 2019 LAIX (Yi Sun) +// Copyright 2019 SmartAction LLC (kkm) +// Copyright 2016 Brno University of Technology (author: Karel Vesely) +// Copyright 2009-2011 Microsoft Corporation; Ondrej Glembek; Lukas Burget; +// Saarland University + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef KALDI_BASE_KALDI_ERROR_H_ +#define KALDI_BASE_KALDI_ERROR_H_ 1 + +#include +#include +#include +#include +#include +#include + +#include "base/kaldi-types.h" +#include "base/kaldi-utils.h" +/* Important that this file does not depend on any other kaldi headers. */ + +#ifdef _MSC_VER +#define __func__ __FUNCTION__ +#endif + +namespace kaldi { + +/// \addtogroup error_group +/// @{ + +/***** PROGRAM NAME AND VERBOSITY LEVEL *****/ + +/// Called by ParseOptions to set base name (no directory) of the executing +/// program. The name is printed in logging code along with every message, +/// because in our scripts, we often mix together the stderr of many programs. +/// This function is very thread-unsafe. +void SetProgramName(const char *basename); + +/// This is set by util/parse-options.{h,cc} if you set --verbose=? option. +/// Do not use directly, prefer {Get,Set}VerboseLevel(). +extern int32 g_kaldi_verbose_level; + +/// Get verbosity level, usually set via command line '--verbose=' switch. +inline int32 GetVerboseLevel() { return g_kaldi_verbose_level; } + +/// This should be rarely used, except by programs using Kaldi as library; +/// command-line programs set the verbose level automatically from ParseOptions. +inline void SetVerboseLevel(int32 i) { g_kaldi_verbose_level = i; } + +/***** KALDI LOGGING *****/ + +/// Log message severity and source location info. +struct LogMessageEnvelope { + /// Message severity. In addition to these levels, positive values (1 to 6) + /// specify verbose logging level. Verbose messages are produced only when + /// SetVerboseLevel() has been called to set logging level to at least the + /// corresponding value. + enum Severity { + kAssertFailed = -3, //!< Assertion failure. abort() will be called. + kError = -2, //!< Fatal error. KaldiFatalError will be thrown. + kWarning = -1, //!< Indicates a recoverable but abnormal condition. + kInfo = 0, //!< Informational message. 
+ }; + int severity; //!< A Severity value, or positive verbosity level. + const char *func; //!< Name of the function invoking the logging. + const char *file; //!< Source file name with up to 1 leading directory. + int32 line; // MessageLogger &operator<<(const T &val) { + ss_ << val; + return *this; + } + + // When assigned a MessageLogger, log its contents. + struct Log final { + void operator=(const MessageLogger &logger) { logger.LogMessage(); } + }; + + // When assigned a MessageLogger, log its contents and then throw + // a KaldiFatalError. + struct LogAndThrow final { + [[noreturn]] void operator=(const MessageLogger &logger) { + logger.LogMessage(); + throw KaldiFatalError(logger.GetMessage()); + } + }; + +private: + std::string GetMessage() const { return ss_.str(); } + void LogMessage() const; + + LogMessageEnvelope envelope_; + std::ostringstream ss_; +}; + +// Logging macros. +#define KALDI_ERR \ + ::kaldi::MessageLogger::LogAndThrow() = ::kaldi::MessageLogger( \ + ::kaldi::LogMessageEnvelope::kError, __func__, __FILE__, __LINE__) +#define KALDI_WARN \ + ::kaldi::MessageLogger::Log() = ::kaldi::MessageLogger( \ + ::kaldi::LogMessageEnvelope::kWarning, __func__, __FILE__, __LINE__) +#define KALDI_LOG \ + ::kaldi::MessageLogger::Log() = ::kaldi::MessageLogger( \ + ::kaldi::LogMessageEnvelope::kInfo, __func__, __FILE__, __LINE__) +#define KALDI_VLOG(v) \ + if ((v) <= ::kaldi::GetVerboseLevel()) \ + ::kaldi::MessageLogger::Log() = \ + ::kaldi::MessageLogger((::kaldi::LogMessageEnvelope::Severity)(v), \ + __func__, __FILE__, __LINE__) + +/***** KALDI ASSERTS *****/ + +[[noreturn]] void KaldiAssertFailure_(const char *func, const char *file, + int32 line, const char *cond_str); + +// Note on KALDI_ASSERT and KALDI_PARANOID_ASSERT: +// +// A single block {} around if /else does not work, because it causes +// syntax error (unmatched else block) in the following code: +// +// if (condition) +// KALDI_ASSERT(condition2); +// else +// SomethingElse(); +// +// 
do {} while(0) -- note there is no semicolon at the end! -- works nicely, +// and compilers will be able to optimize the loop away (as the condition +// is always false). +// +// Also see KALDI_COMPILE_TIME_ASSERT, defined in base/kaldi-utils.h, and +// KALDI_ASSERT_IS_INTEGER_TYPE and KALDI_ASSERT_IS_FLOATING_TYPE, also defined +// there. +#ifndef NDEBUG +#define KALDI_ASSERT(cond) \ + do { \ + if (cond) \ + (void)0; \ + else \ + ::kaldi::KaldiAssertFailure_(__func__, __FILE__, __LINE__, #cond); \ + } while (0) +#else +#define KALDI_ASSERT(cond) (void)0 +#endif + +// Some more expensive asserts only checked if this defined. +#ifdef KALDI_PARANOID +#define KALDI_PARANOID_ASSERT(cond) \ + do { \ + if (cond) \ + (void)0; \ + else \ + ::kaldi::KaldiAssertFailure_(__func__, __FILE__, __LINE__, #cond); \ + } while (0) +#else +#define KALDI_PARANOID_ASSERT(cond) (void)0 +#endif + +/***** THIRD-PARTY LOG-HANDLER *****/ + +/// Type of third-party logging function. +typedef void (*LogHandler)(const LogMessageEnvelope &envelope, + const char *message); + +/// Set logging handler. If called with a non-NULL function pointer, the +/// function pointed by it is called to send messages to a caller-provided log. +/// If called with a NULL pointer, restores default Kaldi error logging to +/// stderr. This function is obviously not thread safe; the log handler must be. +/// Returns a previously set logging handler pointer, or NULL. +LogHandler SetLogHandler(LogHandler); + +/// @} end "addtogroup error_group" + +// Functions within internal is exported for testing only, do not use. 
+namespace internal { +bool LocateSymbolRange(const std::string &trace_name, size_t *begin, + size_t *end); +} // namespace internal +} // namespace kaldi + +#endif // KALDI_BASE_KALDI_ERROR_H_ diff --git a/speechx/speechx/kaldi/base/kaldi-math.cc b/speechx/speechx/kaldi/base/kaldi-math.cc new file mode 100644 index 00000000..484c80d4 --- /dev/null +++ b/speechx/speechx/kaldi/base/kaldi-math.cc @@ -0,0 +1,162 @@ +// base/kaldi-math.cc + +// Copyright 2009-2011 Microsoft Corporation; Yanmin Qian; +// Saarland University; Jan Silovsky + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/kaldi-math.h" +#ifndef _MSC_VER +#include +#include +#endif +#include +#include + +namespace kaldi { +// These routines are tested in matrix/matrix-test.cc + +int32 RoundUpToNearestPowerOfTwo(int32 n) { + KALDI_ASSERT(n > 0); + n--; + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; + return n+1; +} + +static std::mutex _RandMutex; + +int Rand(struct RandomState* state) { +#if !defined(_POSIX_THREAD_SAFE_FUNCTIONS) + // On Windows and Cygwin, just call Rand() + return rand(); +#else + if (state) { + return rand_r(&(state->seed)); + } else { + std::lock_guard lock(_RandMutex); + return rand(); + } +#endif +} + +RandomState::RandomState() { + // we initialize it as Rand() + 27437 instead of just Rand(), because on some + // systems, e.g. at the very least Mac OSX Yosemite and later, it seems to be + // the case that rand_r when initialized with rand() will give you the exact + // same sequence of numbers that rand() will give if you keep calling rand() + // after that initial call. This can cause problems with repeated sequences. + // For example if you initialize two RandomState structs one after the other + // without calling rand() in between, they would give you the same sequence + // offset by one (if we didn't have the "+ 27437" in the code). 27437 is just + // a randomly chosen prime number. + seed = Rand() + 27437; +} + +bool WithProb(BaseFloat prob, struct RandomState* state) { + KALDI_ASSERT(prob >= 0 && prob <= 1.1); // prob should be <= 1.0, + // but we allow slightly larger values that could arise from roundoff in + // previous calculations. + KALDI_COMPILE_TIME_ASSERT(RAND_MAX > 128 * 128); + if (prob == 0) return false; + else if (prob == 1.0) return true; + else if (prob * RAND_MAX < 128.0) { + // prob is very small but nonzero, and the "main algorithm" + // wouldn't work that well. So: with probability 1/128, we + // return WithProb (prob * 128), else return false. 
+ if (Rand(state) < RAND_MAX / 128) { // with probability 128... + // Note: we know that prob * 128.0 < 1.0, because + // we asserted RAND_MAX > 128 * 128. + return WithProb(prob * 128.0); + } else { + return false; + } + } else { + return (Rand(state) < ((RAND_MAX + static_cast(1.0)) * prob)); + } +} + +int32 RandInt(int32 min_val, int32 max_val, struct RandomState* state) { + // This is not exact. + KALDI_ASSERT(max_val >= min_val); + if (max_val == min_val) return min_val; + +#ifdef _MSC_VER + // RAND_MAX is quite small on Windows -> may need to handle larger numbers. + if (RAND_MAX > (max_val-min_val)*8) { + // *8 to avoid large inaccuracies in probability, from the modulus... + return min_val + + ((unsigned int)Rand(state) % (unsigned int)(max_val+1-min_val)); + } else { + if ((unsigned int)(RAND_MAX*RAND_MAX) > + (unsigned int)((max_val+1-min_val)*8)) { + // *8 to avoid inaccuracies in probability, from the modulus... + return min_val + ( (unsigned int)( (Rand(state)+RAND_MAX*Rand(state))) + % (unsigned int)(max_val+1-min_val)); + } else { + KALDI_ERR << "rand_int failed because we do not support such large " + "random numbers. (Extend this function)."; + } + } +#else + return min_val + + (static_cast(Rand(state)) % static_cast(max_val+1-min_val)); +#endif +} + +// Returns poisson-distributed random number. +// Take care: this takes time proportional +// to lambda. Faster algorithms exist but are more complex. +int32 RandPoisson(float lambda, struct RandomState* state) { + // Knuth's algorithm. 
+ KALDI_ASSERT(lambda >= 0); + float L = expf(-lambda), p = 1.0; + int32 k = 0; + do { + k++; + float u = RandUniform(state); + p *= u; + } while (p > L); + return k-1; +} + +void RandGauss2(float *a, float *b, RandomState *state) { + KALDI_ASSERT(a); + KALDI_ASSERT(b); + float u1 = RandUniform(state); + float u2 = RandUniform(state); + u1 = sqrtf(-2.0f * logf(u1)); + u2 = 2.0f * M_PI * u2; + *a = u1 * cosf(u2); + *b = u1 * sinf(u2); +} + +void RandGauss2(double *a, double *b, RandomState *state) { + KALDI_ASSERT(a); + KALDI_ASSERT(b); + float a_float, b_float; + // Just because we're using doubles doesn't mean we need super-high-quality + // random numbers, so we just use the floating-point version internally. + RandGauss2(&a_float, &b_float, state); + *a = a_float; + *b = b_float; +} + + +} // end namespace kaldi diff --git a/speechx/speechx/kaldi/base/kaldi-math.h b/speechx/speechx/kaldi/base/kaldi-math.h new file mode 100644 index 00000000..93c265ee --- /dev/null +++ b/speechx/speechx/kaldi/base/kaldi-math.h @@ -0,0 +1,363 @@ +// base/kaldi-math.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Yanmin Qian; +// Jan Silovsky; Saarland University +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef KALDI_BASE_KALDI_MATH_H_ +#define KALDI_BASE_KALDI_MATH_H_ 1 + +#ifdef _MSC_VER +#include +#endif + +#include +#include +#include + +#include "base/kaldi-types.h" +#include "base/kaldi-common.h" + + +#ifndef DBL_EPSILON +#define DBL_EPSILON 2.2204460492503131e-16 +#endif +#ifndef FLT_EPSILON +#define FLT_EPSILON 1.19209290e-7f +#endif + +#ifndef M_PI +#define M_PI 3.1415926535897932384626433832795 +#endif + +#ifndef M_SQRT2 +#define M_SQRT2 1.4142135623730950488016887 +#endif + +#ifndef M_2PI +#define M_2PI 6.283185307179586476925286766559005 +#endif + +#ifndef M_SQRT1_2 +#define M_SQRT1_2 0.7071067811865475244008443621048490 +#endif + +#ifndef M_LOG_2PI +#define M_LOG_2PI 1.8378770664093454835606594728112 +#endif + +#ifndef M_LN2 +#define M_LN2 0.693147180559945309417232121458 +#endif + +#ifndef M_LN10 +#define M_LN10 2.302585092994045684017991454684 +#endif + + +#define KALDI_ISNAN std::isnan +#define KALDI_ISINF std::isinf +#define KALDI_ISFINITE(x) std::isfinite(x) + +#if !defined(KALDI_SQR) +# define KALDI_SQR(x) ((x) * (x)) +#endif + +namespace kaldi { + +#if !defined(_MSC_VER) || (_MSC_VER >= 1900) +inline double Exp(double x) { return exp(x); } +#ifndef KALDI_NO_EXPF +inline float Exp(float x) { return expf(x); } +#else +inline float Exp(float x) { return exp(static_cast(x)); } +#endif // KALDI_NO_EXPF +#else +inline double Exp(double x) { return exp(x); } +#if !defined(__INTEL_COMPILER) && _MSC_VER == 1800 && defined(_M_X64) +// Microsoft CL v18.0 buggy 64-bit implementation of +// expf() incorrectly returns -inf for exp(-inf). 
+inline float Exp(float x) { return exp(static_cast(x)); } +#else +inline float Exp(float x) { return expf(x); } +#endif // !defined(__INTEL_COMPILER) && _MSC_VER == 1800 && defined(_M_X64) +#endif // !defined(_MSC_VER) || (_MSC_VER >= 1900) + +inline double Log(double x) { return log(x); } +inline float Log(float x) { return logf(x); } + +#if !defined(_MSC_VER) || (_MSC_VER >= 1700) +inline double Log1p(double x) { return log1p(x); } +inline float Log1p(float x) { return log1pf(x); } +#else +inline double Log1p(double x) { + const double cutoff = 1.0e-08; + if (x < cutoff) + return x - 0.5 * x * x; + else + return Log(1.0 + x); +} + +inline float Log1p(float x) { + const float cutoff = 1.0e-07; + if (x < cutoff) + return x - 0.5 * x * x; + else + return Log(1.0 + x); +} +#endif + +static const double kMinLogDiffDouble = Log(DBL_EPSILON); // negative! +static const float kMinLogDiffFloat = Log(FLT_EPSILON); // negative! + +// -infinity +const float kLogZeroFloat = -std::numeric_limits::infinity(); +const double kLogZeroDouble = -std::numeric_limits::infinity(); +const BaseFloat kLogZeroBaseFloat = -std::numeric_limits::infinity(); + +// Returns a random integer between 0 and RAND_MAX, inclusive +int Rand(struct RandomState* state = NULL); + +// State for thread-safe random number generator +struct RandomState { + RandomState(); + unsigned seed; +}; + +// Returns a random integer between first and last inclusive. +int32 RandInt(int32 first, int32 last, struct RandomState* state = NULL); + +// Returns true with probability "prob", +bool WithProb(BaseFloat prob, struct RandomState* state = NULL); +// with 0 <= prob <= 1 [we check this]. +// Internally calls Rand(). This function is carefully implemented so +// that it should work even if prob is very small. + +/// Returns a random number strictly between 0 and 1. 
+inline float RandUniform(struct RandomState* state = NULL) { + return static_cast((Rand(state) + 1.0) / (RAND_MAX+2.0)); +} + +inline float RandGauss(struct RandomState* state = NULL) { + return static_cast(sqrtf (-2 * Log(RandUniform(state))) + * cosf(2*M_PI*RandUniform(state))); +} + +// Returns poisson-distributed random number. Uses Knuth's algorithm. +// Take care: this takes time proportional +// to lambda. Faster algorithms exist but are more complex. +int32 RandPoisson(float lambda, struct RandomState* state = NULL); + +// Returns a pair of gaussian random numbers. Uses Box-Muller transform +void RandGauss2(float *a, float *b, RandomState *state = NULL); +void RandGauss2(double *a, double *b, RandomState *state = NULL); + +// Also see Vector::RandCategorical(). + +// This is a randomized pruning mechanism that preserves expectations, +// that we typically use to prune posteriors. +template +inline Float RandPrune(Float post, BaseFloat prune_thresh, + struct RandomState* state = NULL) { + KALDI_ASSERT(prune_thresh >= 0.0); + if (post == 0.0 || std::abs(post) >= prune_thresh) + return post; + return (post >= 0 ? 1.0 : -1.0) * + (RandUniform(state) <= fabs(post)/prune_thresh ? prune_thresh : 0.0); +} + +// returns log(exp(x) + exp(y)). +inline double LogAdd(double x, double y) { + double diff; + + if (x < y) { + diff = x - y; + x = y; + } else { + diff = y - x; + } + // diff is negative. x is now the larger one. + + if (diff >= kMinLogDiffDouble) { + double res; + res = x + Log1p(Exp(diff)); + return res; + } else { + return x; // return the larger one. + } +} + + +// returns log(exp(x) + exp(y)). +inline float LogAdd(float x, float y) { + float diff; + + if (x < y) { + diff = x - y; + x = y; + } else { + diff = y - x; + } + // diff is negative. x is now the larger one. + + if (diff >= kMinLogDiffFloat) { + float res; + res = x + Log1p(Exp(diff)); + return res; + } else { + return x; // return the larger one. + } +} + + +// returns log(exp(x) - exp(y)). 
+inline double LogSub(double x, double y) { + if (y >= x) { // Throws exception if y>=x. + if (y == x) + return kLogZeroDouble; + else + KALDI_ERR << "Cannot subtract a larger from a smaller number."; + } + + double diff = y - x; // Will be negative. + double res = x + Log(1.0 - Exp(diff)); + + // res might be NAN if diff ~0.0, and 1.0-exp(diff) == 0 to machine precision + if (KALDI_ISNAN(res)) + return kLogZeroDouble; + return res; +} + + +// returns log(exp(x) - exp(y)). +inline float LogSub(float x, float y) { + if (y >= x) { // Throws exception if y>=x. + if (y == x) + return kLogZeroDouble; + else + KALDI_ERR << "Cannot subtract a larger from a smaller number."; + } + + float diff = y - x; // Will be negative. + float res = x + Log(1.0f - Exp(diff)); + + // res might be NAN if diff ~0.0, and 1.0-exp(diff) == 0 to machine precision + if (KALDI_ISNAN(res)) + return kLogZeroFloat; + return res; +} + +/// return abs(a - b) <= relative_tolerance * (abs(a)+abs(b)). +static inline bool ApproxEqual(float a, float b, + float relative_tolerance = 0.001) { + // a==b handles infinities. + if (a == b) return true; + float diff = std::abs(a-b); + if (diff == std::numeric_limits::infinity() + || diff != diff) return false; // diff is +inf or nan. + return (diff <= relative_tolerance*(std::abs(a)+std::abs(b))); +} + +/// assert abs(a - b) <= relative_tolerance * (abs(a)+abs(b)) +static inline void AssertEqual(float a, float b, + float relative_tolerance = 0.001) { + // a==b handles infinities. + KALDI_ASSERT(ApproxEqual(a, b, relative_tolerance)); +} + + +// RoundUpToNearestPowerOfTwo does the obvious thing. It crashes if n <= 0. +int32 RoundUpToNearestPowerOfTwo(int32 n); + +/// Returns a / b, rounding towards negative infinity in all cases. 
+static inline int32 DivideRoundingDown(int32 a, int32 b) { + KALDI_ASSERT(b != 0); + if (a * b >= 0) + return a / b; + else if (a < 0) + return (a - b + 1) / b; + else + return (a - b - 1) / b; +} + +template I Gcd(I m, I n) { + if (m == 0 || n == 0) { + if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors. + KALDI_ERR << "Undefined GCD since m = 0, n = 0."; + } + return (m == 0 ? (n > 0 ? n : -n) : ( m > 0 ? m : -m)); + // return absolute value of whichever is nonzero + } + // could use compile-time assertion + // but involves messing with complex template stuff. + KALDI_ASSERT(std::numeric_limits::is_integer); + while (1) { + m %= n; + if (m == 0) return (n > 0 ? n : -n); + n %= m; + if (n == 0) return (m > 0 ? m : -m); + } +} + +/// Returns the least common multiple of two integers. Will +/// crash unless the inputs are positive. +template I Lcm(I m, I n) { + KALDI_ASSERT(m > 0 && n > 0); + I gcd = Gcd(m, n); + return gcd * (m/gcd) * (n/gcd); +} + + +template void Factorize(I m, std::vector *factors) { + // Splits a number into its prime factors, in sorted order from + // least to greatest, with duplication. A very inefficient + // algorithm, which is mainly intended for use in the + // mixed-radix FFT computation (where we assume most factors + // are small). + KALDI_ASSERT(factors != NULL); + KALDI_ASSERT(m >= 1); // Doesn't work for zero or negative numbers. + factors->clear(); + I small_factors[10] = { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29 }; + + // First try small factors. + for (I i = 0; i < 10; i++) { + if (m == 1) return; // We're done. + while (m % small_factors[i] == 0) { + m /= small_factors[i]; + factors->push_back(small_factors[i]); + } + } + // Next try all odd numbers starting from 31. 
+ for (I j = 31;; j += 2) { + if (m == 1) return; + while (m % j == 0) { + m /= j; + factors->push_back(j); + } + } +} + +inline double Hypot(double x, double y) { return hypot(x, y); } +inline float Hypot(float x, float y) { return hypotf(x, y); } + + + + +} // namespace kaldi + + +#endif // KALDI_BASE_KALDI_MATH_H_ diff --git a/speechx/speechx/kaldi/base/kaldi-types.h b/speechx/speechx/kaldi/base/kaldi-types.h new file mode 100644 index 00000000..16a1a5b9 --- /dev/null +++ b/speechx/speechx/kaldi/base/kaldi-types.h @@ -0,0 +1,76 @@ +// base/kaldi-types.h + +// Copyright 2009-2011 Microsoft Corporation; Saarland University; +// Jan Silovsky; Yanmin Qian + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_BASE_KALDI_TYPES_H_ +#define KALDI_BASE_KALDI_TYPES_H_ 1 + +namespace kaldi { +// TYPEDEFS .................................................................. 
+#if (KALDI_DOUBLEPRECISION != 0) +typedef double BaseFloat; +#else +typedef float BaseFloat; +#endif +} + +#ifdef _MSC_VER +#include +#define ssize_t SSIZE_T +#endif + +// we can do this a different way if some platform +// we find in the future lacks stdint.h +#include + +// for discussion on what to do if you need compile kaldi +// without OpenFST, see the bottom of this this file +/* +#include + +namespace kaldi { + using ::int16; + using ::int32; + using ::int64; + using ::uint16; + using ::uint32; + using ::uint64; + typedef float float32; + typedef double double64; +} // end namespace kaldi +*/ +// In a theoretical case you decide compile Kaldi without the OpenFST +// comment the previous namespace statement and uncomment the following + +namespace kaldi { + typedef int8_t int8; + typedef int16_t int16; + typedef int32_t int32; + typedef int64_t int64; + + typedef uint8_t uint8; + typedef uint16_t uint16; + typedef uint32_t uint32; + typedef uint64_t uint64; + typedef float float32; + typedef double double64; +} // end namespace kaldi + + +#endif // KALDI_BASE_KALDI_TYPES_H_ diff --git a/speechx/speechx/kaldi/base/kaldi-utils.cc b/speechx/speechx/kaldi/base/kaldi-utils.cc new file mode 100644 index 00000000..432da426b --- /dev/null +++ b/speechx/speechx/kaldi/base/kaldi-utils.cc @@ -0,0 +1,55 @@ +// base/kaldi-utils.cc +// Copyright 2009-2011 Karel Vesely; Yanmin Qian; Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. 
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifdef _WIN32_WINNT_WIN8 +#include +#elif defined(_WIN32) || defined(_MSC_VER) || defined(MINGW) +#include +#if defined(_MSC_VER) && _MSC_VER < 1900 +#define snprintf _snprintf +#endif /* _MSC_VER < 1900 */ +#else +#include +#endif + +#include +#include "base/kaldi-common.h" + + +namespace kaldi { + +std::string CharToString(const char &c) { + char buf[20]; + if (std::isprint(c)) + snprintf(buf, sizeof(buf), "\'%c\'", c); + else + snprintf(buf, sizeof(buf), "[character %d]", static_cast(c)); + return (std::string) buf; +} + +void Sleep(float seconds) { +#if defined(_MSC_VER) || defined(MINGW) + ::Sleep(static_cast(seconds * 1000.0)); +#elif defined(__CYGWIN__) + sleep(static_cast(seconds)); +#else + usleep(static_cast(seconds * 1000000.0)); +#endif +} + +} // end namespace kaldi diff --git a/speechx/speechx/kaldi/base/kaldi-utils.h b/speechx/speechx/kaldi/base/kaldi-utils.h new file mode 100644 index 00000000..c9d6fd95 --- /dev/null +++ b/speechx/speechx/kaldi/base/kaldi-utils.h @@ -0,0 +1,155 @@ +// base/kaldi-utils.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; +// Saarland University; Karel Vesely; Yanmin Qian + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef KALDI_BASE_KALDI_UTILS_H_ +#define KALDI_BASE_KALDI_UTILS_H_ 1 + +#if defined(_MSC_VER) +# define WIN32_LEAN_AND_MEAN +# define NOMINMAX +# include +#endif + +#ifdef _MSC_VER +#include +#define unlink _unlink +#else +#include +#endif + +#include +#include + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4056 4305 4800 4267 4996 4756 4661) +#if _MSC_VER < 1400 +#define __restrict__ +#else +#define __restrict__ __restrict +#endif +#endif + +#if defined(_MSC_VER) +# define KALDI_MEMALIGN(align, size, pp_orig) \ + (*(pp_orig) = _aligned_malloc(size, align)) +# define KALDI_MEMALIGN_FREE(x) _aligned_free(x) +#elif defined(__CYGWIN__) +# define KALDI_MEMALIGN(align, size, pp_orig) \ + (*(pp_orig) = aligned_alloc(align, size)) +# define KALDI_MEMALIGN_FREE(x) free(x) +#else +# define KALDI_MEMALIGN(align, size, pp_orig) \ + (!posix_memalign(pp_orig, align, size) ? *(pp_orig) : NULL) +# define KALDI_MEMALIGN_FREE(x) free(x) +#endif + +#ifdef __ICC +#pragma warning(disable: 383) // ICPC remark we don't want. +#pragma warning(disable: 810) // ICPC remark we don't want. +#pragma warning(disable: 981) // ICPC remark we don't want. +#pragma warning(disable: 1418) // ICPC remark we don't want. +#pragma warning(disable: 444) // ICPC remark we don't want. +#pragma warning(disable: 869) // ICPC remark we don't want. +#pragma warning(disable: 1287) // ICPC remark we don't want. +#pragma warning(disable: 279) // ICPC remark we don't want. +#pragma warning(disable: 981) // ICPC remark we don't want. +#endif + + +namespace kaldi { + + +// CharToString prints the character in a human-readable form, for debugging. +std::string CharToString(const char &c); + + +inline int MachineIsLittleEndian() { + int check = 1; + return (*reinterpret_cast(&check) != 0); +} + +// This function kaldi::Sleep() provides a portable way +// to sleep for a possibly fractional +// number of seconds. On Windows it's only accurate to microseconds. 
+void Sleep(float seconds); +} + +#define KALDI_SWAP8(a) { \ + int t = (reinterpret_cast(&a))[0];\ + (reinterpret_cast(&a))[0]=(reinterpret_cast(&a))[7];\ + (reinterpret_cast(&a))[7]=t;\ + t = (reinterpret_cast(&a))[1];\ + (reinterpret_cast(&a))[1]=(reinterpret_cast(&a))[6];\ + (reinterpret_cast(&a))[6]=t;\ + t = (reinterpret_cast(&a))[2];\ + (reinterpret_cast(&a))[2]=(reinterpret_cast(&a))[5];\ + (reinterpret_cast(&a))[5]=t;\ + t = (reinterpret_cast(&a))[3];\ + (reinterpret_cast(&a))[3]=(reinterpret_cast(&a))[4];\ + (reinterpret_cast(&a))[4]=t;} +#define KALDI_SWAP4(a) { \ + int t = (reinterpret_cast(&a))[0];\ + (reinterpret_cast(&a))[0]=(reinterpret_cast(&a))[3];\ + (reinterpret_cast(&a))[3]=t;\ + t = (reinterpret_cast(&a))[1];\ + (reinterpret_cast(&a))[1]=(reinterpret_cast(&a))[2];\ + (reinterpret_cast(&a))[2]=t;} +#define KALDI_SWAP2(a) { \ + int t = (reinterpret_cast(&a))[0];\ + (reinterpret_cast(&a))[0]=(reinterpret_cast(&a))[1];\ + (reinterpret_cast(&a))[1]=t;} + + +// Makes copy constructor and operator= private. 
+#define KALDI_DISALLOW_COPY_AND_ASSIGN(type) \ + type(const type&); \ + void operator = (const type&) + +template class KaldiCompileTimeAssert { }; +template<> class KaldiCompileTimeAssert { + public: + static inline void Check() { } +}; + +#define KALDI_COMPILE_TIME_ASSERT(b) KaldiCompileTimeAssert<(b)>::Check() + +#define KALDI_ASSERT_IS_INTEGER_TYPE(I) \ + KaldiCompileTimeAssert::is_specialized \ + && std::numeric_limits::is_integer>::Check() + +#define KALDI_ASSERT_IS_FLOATING_TYPE(F) \ + KaldiCompileTimeAssert::is_specialized \ + && !std::numeric_limits::is_integer>::Check() + +#if defined(_MSC_VER) +#define KALDI_STRCASECMP _stricmp +#elif defined(__CYGWIN__) +#include +#define KALDI_STRCASECMP strcasecmp +#else +#define KALDI_STRCASECMP strcasecmp +#endif +#ifdef _MSC_VER +# define KALDI_STRTOLL(cur_cstr, end_cstr) _strtoi64(cur_cstr, end_cstr, 10); +#else +# define KALDI_STRTOLL(cur_cstr, end_cstr) strtoll(cur_cstr, end_cstr, 10); +#endif + +#endif // KALDI_BASE_KALDI_UTILS_H_ diff --git a/speechx/speechx/kaldi/base/timer.cc b/speechx/speechx/kaldi/base/timer.cc new file mode 100644 index 00000000..ce4ef292 --- /dev/null +++ b/speechx/speechx/kaldi/base/timer.cc @@ -0,0 +1,85 @@ +// base/timer.cc + +// Copyright 2018 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "base/timer.h" +#include "base/kaldi-error.h" +#include +#include +#include +#include + +namespace kaldi { + +class ProfileStats { + public: + void AccStats(const char *function_name, double elapsed) { + std::unordered_map::iterator + iter = map_.find(function_name); + if (iter == map_.end()) { + map_[function_name] = ProfileStatsEntry(function_name); + map_[function_name].total_time = elapsed; + } else { + iter->second.total_time += elapsed; + } + } + ~ProfileStats() { + // This map makes sure we agglomerate the time if there were any duplicate + // addresses of strings. + std::unordered_map total_time; + for (auto iter = map_.begin(); iter != map_.end(); iter++) + total_time[iter->second.name] += iter->second.total_time; + + ReverseSecondComparator comp; + std::vector > pairs(total_time.begin(), + total_time.end()); + std::sort(pairs.begin(), pairs.end(), comp); + for (size_t i = 0; i < pairs.size(); i++) { + KALDI_LOG << "Time taken in " << pairs[i].first << " is " + << std::fixed << std::setprecision(2) << pairs[i].second << "s."; + } + } + private: + + struct ProfileStatsEntry { + std::string name; + double total_time; + ProfileStatsEntry() { } + ProfileStatsEntry(const char *name): name(name) { } + }; + + struct ReverseSecondComparator { + bool operator () (const std::pair &a, + const std::pair &b) { + return a.second > b.second; + } + }; + + // Note: this map is keyed on the address of the string, there is no proper + // hash function. The assumption is that the strings are compile-time + // constants. 
+ std::unordered_map map_; +}; + +ProfileStats g_profile_stats; + +Profiler::~Profiler() { + g_profile_stats.AccStats(name_, tim_.Elapsed()); +} + +} // namespace kaldi diff --git a/speechx/speechx/kaldi/base/timer.h b/speechx/speechx/kaldi/base/timer.h new file mode 100644 index 00000000..0e033766 --- /dev/null +++ b/speechx/speechx/kaldi/base/timer.h @@ -0,0 +1,115 @@ +// base/timer.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_BASE_TIMER_H_ +#define KALDI_BASE_TIMER_H_ + +#include "base/kaldi-utils.h" +#include "base/kaldi-error.h" + + +#if defined(_MSC_VER) || defined(MINGW) + +namespace kaldi { +class Timer { + public: + Timer() { Reset(); } + + // You can initialize with bool to control whether or not you want the time to + // be set when the object is created. + explicit Timer(bool set_timer) { if (set_timer) Reset(); } + + void Reset() { + QueryPerformanceCounter(&time_start_); + } + double Elapsed() const { + LARGE_INTEGER time_end; + LARGE_INTEGER freq; + QueryPerformanceCounter(&time_end); + + if (QueryPerformanceFrequency(&freq) == 0) { + // Hardware does not support this. 
+ return 0.0; + } + return (static_cast(time_end.QuadPart) - + static_cast(time_start_.QuadPart)) / + (static_cast(freq.QuadPart)); + } + private: + LARGE_INTEGER time_start_; +}; + + +#else +#include +#include + +namespace kaldi { +class Timer { + public: + Timer() { Reset(); } + + // You can initialize with bool to control whether or not you want the time to + // be set when the object is created. + explicit Timer(bool set_timer) { if (set_timer) Reset(); } + + void Reset() { gettimeofday(&this->time_start_, &time_zone_); } + + /// Returns time in seconds. + double Elapsed() const { + struct timeval time_end; + struct timezone time_zone; + gettimeofday(&time_end, &time_zone); + double t1, t2; + t1 = static_cast(time_start_.tv_sec) + + static_cast(time_start_.tv_usec)/(1000*1000); + t2 = static_cast(time_end.tv_sec) + + static_cast(time_end.tv_usec)/(1000*1000); + return t2-t1; + } + + private: + struct timeval time_start_; + struct timezone time_zone_; +}; + +#endif + +class Profiler { + public: + // Caution: the 'const char' should always be a string constant; for speed, + // internally the profiling code uses the address of it as a lookup key. + Profiler(const char *function_name): name_(function_name) { } + ~Profiler(); + private: + Timer tim_; + const char *name_; +}; + +// To add timing info for a function, you just put +// KALDI_PROFILE; +// at the beginning of the function. Caution: this doesn't +// include the class name. +#define KALDI_PROFILE Profiler _profiler(__func__) + + + +} // namespace kaldi + + +#endif // KALDI_BASE_TIMER_H_ diff --git a/speechx/speechx/kaldi/base/version.h b/speechx/speechx/kaldi/base/version.h new file mode 100644 index 00000000..a79a5758 --- /dev/null +++ b/speechx/speechx/kaldi/base/version.h @@ -0,0 +1,4 @@ +// This file was automatically created by ./get_version.sh. +// It is only included by ./kaldi-error.cc. 
+#define KALDI_VERSION "5.5.544~2-f21d7" +#define KALDI_GIT_HEAD "f21d7e768635ca98aeeb43f30e2c6a9f14ab8f0f" diff --git a/speechx/speechx/kaldi/feat/CMakeLists.txt b/speechx/speechx/kaldi/feat/CMakeLists.txt new file mode 100644 index 00000000..8b914962 --- /dev/null +++ b/speechx/speechx/kaldi/feat/CMakeLists.txt @@ -0,0 +1,19 @@ +add_library(kaldi-mfcc + feature-mfcc.cc +) +target_link_libraries(kaldi-mfcc PUBLIC kaldi-feat-common) + +add_library(fbank + feature-fbank.cc +) +target_link_libraries(fbank PUBLIC kaldi-feat-common) + +add_library(kaldi-feat-common + wave-reader.cc + signal.cc + feature-functions.cc + feature-window.cc + resample.cc + mel-computations.cc +) +target_link_libraries(kaldi-feat-common PUBLIC kaldi-base kaldi-matrix kaldi-util) \ No newline at end of file diff --git a/speechx/speechx/kaldi/feat/feature-common-inl.h b/speechx/speechx/kaldi/feat/feature-common-inl.h new file mode 100644 index 00000000..26127a4d --- /dev/null +++ b/speechx/speechx/kaldi/feat/feature-common-inl.h @@ -0,0 +1,99 @@ +// feat/feature-common-inl.h + +// Copyright 2016 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_FEAT_FEATURE_COMMON_INL_H_ +#define KALDI_FEAT_FEATURE_COMMON_INL_H_ + +#include "feat/resample.h" +// Do not include this file directly. 
It is included by feat/feature-common.h + +namespace kaldi { + +template +void OfflineFeatureTpl::ComputeFeatures( + const VectorBase &wave, + BaseFloat sample_freq, + BaseFloat vtln_warp, + Matrix *output) { + KALDI_ASSERT(output != NULL); + BaseFloat new_sample_freq = computer_.GetFrameOptions().samp_freq; + if (sample_freq == new_sample_freq) { + Compute(wave, vtln_warp, output); + } else { + if (new_sample_freq < sample_freq && + ! computer_.GetFrameOptions().allow_downsample) + KALDI_ERR << "Waveform and config sample Frequency mismatch: " + << sample_freq << " .vs " << new_sample_freq + << " (use --allow-downsample=true to allow " + << " downsampling the waveform)."; + else if (new_sample_freq > sample_freq && + ! computer_.GetFrameOptions().allow_upsample) + KALDI_ERR << "Waveform and config sample Frequency mismatch: " + << sample_freq << " .vs " << new_sample_freq + << " (use --allow-upsample=true option to allow " + << " upsampling the waveform)."; + // Resample the waveform. + Vector resampled_wave(wave); + ResampleWaveform(sample_freq, wave, + new_sample_freq, &resampled_wave); + Compute(resampled_wave, vtln_warp, output); + } +} + +template +void OfflineFeatureTpl::Compute( + const VectorBase &wave, + BaseFloat vtln_warp, + Matrix *output) { + KALDI_ASSERT(output != NULL); + int32 rows_out = NumFrames(wave.Dim(), computer_.GetFrameOptions()), + cols_out = computer_.Dim(); + if (rows_out == 0) { + output->Resize(0, 0); + return; + } + output->Resize(rows_out, cols_out); + Vector window; // windowed waveform. + bool use_raw_log_energy = computer_.NeedRawLogEnergy(); + for (int32 r = 0; r < rows_out; r++) { // r is frame index. + BaseFloat raw_log_energy = 0.0; + ExtractWindow(0, wave, r, computer_.GetFrameOptions(), + feature_window_function_, &window, + (use_raw_log_energy ? 
&raw_log_energy : NULL)); + + SubVector output_row(*output, r); + computer_.Compute(raw_log_energy, vtln_warp, &window, &output_row); + } +} + +template +void OfflineFeatureTpl::Compute( + const VectorBase &wave, + BaseFloat vtln_warp, + Matrix *output) const { + OfflineFeatureTpl temp(*this); + // call the non-const version of Compute() on a temporary copy of this object. + // This is a workaround for const-ness that may sometimes be useful in + // multi-threaded code, although it's not optimally efficient. + temp.Compute(wave, vtln_warp, output); +} + +} // end namespace kaldi + +#endif diff --git a/speechx/speechx/kaldi/feat/feature-common.h b/speechx/speechx/kaldi/feat/feature-common.h new file mode 100644 index 00000000..3c2fbd37 --- /dev/null +++ b/speechx/speechx/kaldi/feat/feature-common.h @@ -0,0 +1,176 @@ +// feat/feature-common.h + +// Copyright 2016 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABILITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_FEAT_FEATURE_COMMON_H_ +#define KALDI_FEAT_FEATURE_COMMON_H_ + +#include +#include +#include "feat/feature-window.h" + +namespace kaldi { +/// @addtogroup feat FeatureExtraction +/// @{ + + + +/// This class is only added for documentation, it is not intended to ever be +/// used. 
+struct ExampleFeatureComputerOptions { + FrameExtractionOptions frame_opts; + // .. more would go here. +}; + +/// This class is only added for documentation, it is not intended to ever be +/// used. It documents the interface of the *Computer classes which wrap the +/// low-level feature extraction. The template argument F of OfflineFeatureTpl must +/// follow this interface. This interface is intended for features such as +/// MFCCs and PLPs which can be computed frame by frame. +class ExampleFeatureComputer { + public: + typedef ExampleFeatureComputerOptions Options; + + /// Returns a reference to the frame-extraction options class, which + /// will be part of our own options class. + const FrameExtractionOptions &GetFrameOptions() const { + return opts_.frame_opts; + } + + /// Returns the feature dimension + int32 Dim() const; + + /// Returns true if this function may inspect the raw log-energy of the signal + /// (before windowing and pre-emphasis); it's safe to always return true, but + /// setting it to false enables an optimization. + bool NeedRawLogEnergy() const { return true; } + + /// constructor from options class; it should not store a reference or pointer + /// to the options class but should copy it. + explicit ExampleFeatureComputer(const ExampleFeatureComputerOptions &opts): + opts_(opts) { } + + /// Copy constructor; all of these classes must have one. + ExampleFeatureComputer(const ExampleFeatureComputer &other); + + /** + Function that computes one frame of features from + one frame of signal. + + @param [in] signal_raw_log_energy The log-energy of the frame of the signal + prior to windowing and pre-emphasis, or + log(numeric_limits::min()), whichever is greater. Must be + ignored by this function if this class returns false from + this->NeedRawLogEnergy(). + @param [in] vtln_warp The VTLN warping factor that the user wants + to be applied when computing features for this utterance. Will + normally be 1.0, meaning no warping is to be done. 
The value will + be ignored for feature types that don't support VLTN, such as + spectrogram features. + @param [in] signal_frame One frame of the signal, + as extracted using the function ExtractWindow() using the options + returned by this->GetFrameOptions(). The function will use the + vector as a workspace, which is why it's a non-const pointer. + @param [out] feature Pointer to a vector of size this->Dim(), to which + the computed feature will be written. + */ + void Compute(BaseFloat signal_raw_log_energy, + BaseFloat vtln_warp, + VectorBase *signal_frame, + VectorBase *feature); + + private: + // disallow assignment. + ExampleFeatureComputer &operator = (const ExampleFeatureComputer &in); + Options opts_; +}; + + +/// This templated class is intended for offline feature extraction, i.e. where +/// you have access to the entire signal at the start. It exists mainly to be +/// drop-in replacement for the old (pre-2016) classes Mfcc, Plp and so on, for +/// use in the offline case. In April 2016 we reorganized the online +/// feature-computation code for greater modularity and to have correct support +/// for the snip-edges=false option. +template +class OfflineFeatureTpl { + public: + typedef typename F::Options Options; + + // Note: feature_window_function_ is the windowing function, which initialized + // using the options class, that we cache at this level. + OfflineFeatureTpl(const Options &opts): + computer_(opts), + feature_window_function_(computer_.GetFrameOptions()) { } + + // Internal (and back-compatibility) interface for computing features, which + // requires that the user has already checked that the sampling frequency + // of the waveform is equal to the sampling frequency specified in + // the frame-extraction options. + void Compute(const VectorBase &wave, + BaseFloat vtln_warp, + Matrix *output); + + // This const version of Compute() is a wrapper that + // calls the non-const version on a temporary object. 
+ // It's less efficient than the non-const version. + void Compute(const VectorBase &wave, + BaseFloat vtln_warp, + Matrix *output) const; + + /** + Computes the features for one file (one sequence of features). + This is the newer interface where you specify the sample frequency + of the input waveform. + @param [in] wave The input waveform + @param [in] sample_freq The sampling frequency with which + 'wave' was sampled. + if sample_freq is higher than the frequency + specified in the config, we will downsample + the waveform, but if lower, it's an error. + @param [in] vtln_warp The VTLN warping factor (will normally + be 1.0) + @param [out] output The matrix of features, where the row-index + is the frame index. + */ + void ComputeFeatures(const VectorBase &wave, + BaseFloat sample_freq, + BaseFloat vtln_warp, + Matrix *output); + + int32 Dim() const { return computer_.Dim(); } + + // Copy constructor. + OfflineFeatureTpl(const OfflineFeatureTpl &other): + computer_(other.computer_), + feature_window_function_(other.feature_window_function_) { } + private: + // Disallow assignment. + OfflineFeatureTpl &operator =(const OfflineFeatureTpl &other); + + F computer_; + FeatureWindowFunction feature_window_function_; +}; + +/// @} End of "addtogroup feat" +} // namespace kaldi + + +#include "feat/feature-common-inl.h" + +#endif // KALDI_FEAT_FEATURE_COMMON_H_ diff --git a/speechx/speechx/kaldi/feat/feature-fbank.cc b/speechx/speechx/kaldi/feat/feature-fbank.cc new file mode 100644 index 00000000..d9ac03e5 --- /dev/null +++ b/speechx/speechx/kaldi/feat/feature-fbank.cc @@ -0,0 +1,125 @@ +// feat/feature-fbank.cc + +// Copyright 2009-2012 Karel Vesely +// 2016 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "feat/feature-fbank.h" + +namespace kaldi { + +FbankComputer::FbankComputer(const FbankOptions &opts): + opts_(opts), srfft_(NULL) { + if (opts.energy_floor > 0.0) + log_energy_floor_ = Log(opts.energy_floor); + + int32 padded_window_size = opts.frame_opts.PaddedWindowSize(); + if ((padded_window_size & (padded_window_size-1)) == 0) // Is a power of two... + srfft_ = new SplitRadixRealFft(padded_window_size); + + // We'll definitely need the filterbanks info for VTLN warping factor 1.0. + // [note: this call caches it.] 
+ GetMelBanks(1.0); +} + +FbankComputer::FbankComputer(const FbankComputer &other): + opts_(other.opts_), log_energy_floor_(other.log_energy_floor_), + mel_banks_(other.mel_banks_), srfft_(NULL) { + for (std::map::iterator iter = mel_banks_.begin(); + iter != mel_banks_.end(); + ++iter) + iter->second = new MelBanks(*(iter->second)); + if (other.srfft_) + srfft_ = new SplitRadixRealFft(*(other.srfft_)); +} + +FbankComputer::~FbankComputer() { + for (std::map::iterator iter = mel_banks_.begin(); + iter != mel_banks_.end(); ++iter) + delete iter->second; + delete srfft_; +} + +const MelBanks* FbankComputer::GetMelBanks(BaseFloat vtln_warp) { + MelBanks *this_mel_banks = NULL; + std::map::iterator iter = mel_banks_.find(vtln_warp); + if (iter == mel_banks_.end()) { + this_mel_banks = new MelBanks(opts_.mel_opts, + opts_.frame_opts, + vtln_warp); + mel_banks_[vtln_warp] = this_mel_banks; + } else { + this_mel_banks = iter->second; + } + return this_mel_banks; +} + +void FbankComputer::Compute(BaseFloat signal_raw_log_energy, + BaseFloat vtln_warp, + VectorBase *signal_frame, + VectorBase *feature) { + + const MelBanks &mel_banks = *(GetMelBanks(vtln_warp)); + + KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.PaddedWindowSize() && + feature->Dim() == this->Dim()); + + + // Compute energy after window function (not the raw one). + if (opts_.use_energy && !opts_.raw_energy) + signal_raw_log_energy = Log(std::max(VecVec(*signal_frame, *signal_frame), + std::numeric_limits::epsilon())); + + if (srfft_ != NULL) // Compute FFT using split-radix algorithm. + srfft_->Compute(signal_frame->Data(), true); + else // An alternative algorithm that works for non-powers-of-two. + RealFft(signal_frame, true); + + // Convert the FFT into a power spectrum. + ComputePowerSpectrum(signal_frame); + SubVector power_spectrum(*signal_frame, 0, + signal_frame->Dim() / 2 + 1); + + // Use magnitude instead of power if requested. 
+ if (!opts_.use_power) + power_spectrum.ApplyPow(0.5); + + int32 mel_offset = ((opts_.use_energy && !opts_.htk_compat) ? 1 : 0); + SubVector mel_energies(*feature, + mel_offset, + opts_.mel_opts.num_bins); + + // Sum with mel fiterbanks over the power spectrum + mel_banks.Compute(power_spectrum, &mel_energies); + if (opts_.use_log_fbank) { + // Avoid log of zero (which should be prevented anyway by dithering). + mel_energies.ApplyFloor(std::numeric_limits::epsilon()); + mel_energies.ApplyLog(); // take the log. + } + + // Copy energy as first value (or the last, if htk_compat == true). + if (opts_.use_energy) { + if (opts_.energy_floor > 0.0 && signal_raw_log_energy < log_energy_floor_) { + signal_raw_log_energy = log_energy_floor_; + } + int32 energy_index = opts_.htk_compat ? opts_.mel_opts.num_bins : 0; + (*feature)(energy_index) = signal_raw_log_energy; + } +} + +} // namespace kaldi diff --git a/speechx/speechx/kaldi/feat/feature-fbank.h b/speechx/speechx/kaldi/feat/feature-fbank.h new file mode 100644 index 00000000..f57d185a --- /dev/null +++ b/speechx/speechx/kaldi/feat/feature-fbank.h @@ -0,0 +1,149 @@ +// feat/feature-fbank.h + +// Copyright 2009-2012 Karel Vesely +// 2016 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef KALDI_FEAT_FEATURE_FBANK_H_ +#define KALDI_FEAT_FEATURE_FBANK_H_ + +#include +#include + +#include "feat/feature-common.h" +#include "feat/feature-functions.h" +#include "feat/feature-window.h" +#include "feat/mel-computations.h" + +namespace kaldi { +/// @addtogroup feat FeatureExtraction +/// @{ + + +/// FbankOptions contains basic options for computing filterbank features. +/// It only includes things that can be done in a "stateless" way, i.e. +/// it does not include energy max-normalization. +/// It does not include delta computation. +struct FbankOptions { + FrameExtractionOptions frame_opts; + MelBanksOptions mel_opts; + bool use_energy; // append an extra dimension with energy to the filter banks + BaseFloat energy_floor; + bool raw_energy; // If true, compute energy before preemphasis and windowing + bool htk_compat; // If true, put energy last (if using energy) + bool use_log_fbank; // if true (default), produce log-filterbank, else linear + bool use_power; // if true (default), use power in filterbank analysis, else magnitude. + + FbankOptions(): mel_opts(23), + // defaults the #mel-banks to 23 for the FBANK computations. + // this seems to be common for 16khz-sampled data, + // but for 8khz-sampled data, 15 may be better. + use_energy(false), + energy_floor(0.0), + raw_energy(true), + htk_compat(false), + use_log_fbank(true), + use_power(true) {} + + void Register(OptionsItf *opts) { + frame_opts.Register(opts); + mel_opts.Register(opts); + opts->Register("use-energy", &use_energy, + "Add an extra dimension with energy to the FBANK output."); + opts->Register("energy-floor", &energy_floor, + "Floor on energy (absolute, not relative) in FBANK computation. " + "Only makes a difference if --use-energy=true; only necessary if " + "--dither=0.0. 
Suggested values: 0.1 or 1.0"); + opts->Register("raw-energy", &raw_energy, + "If true, compute energy before preemphasis and windowing"); + opts->Register("htk-compat", &htk_compat, "If true, put energy last. " + "Warning: not sufficient to get HTK compatible features (need " + "to change other parameters)."); + opts->Register("use-log-fbank", &use_log_fbank, + "If true, produce log-filterbank, else produce linear."); + opts->Register("use-power", &use_power, + "If true, use power, else use magnitude."); + } +}; + + +/// Class for computing mel-filterbank features; see \ref feat_mfcc for more +/// information. +class FbankComputer { + public: + typedef FbankOptions Options; + + explicit FbankComputer(const FbankOptions &opts); + FbankComputer(const FbankComputer &other); + + int32 Dim() const { + return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0); + } + + bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; } + + const FrameExtractionOptions &GetFrameOptions() const { + return opts_.frame_opts; + } + + /** + Function that computes one frame of features from + one frame of signal. + + @param [in] signal_raw_log_energy The log-energy of the frame of the signal + prior to windowing and pre-emphasis, or + log(numeric_limits::min()), whichever is greater. Must be + ignored by this function if this class returns false from + this->NeedsRawLogEnergy(). + @param [in] vtln_warp The VTLN warping factor that the user wants + to be applied when computing features for this utterance. Will + normally be 1.0, meaning no warping is to be done. The value will + be ignored for feature types that don't support VLTN, such as + spectrogram features. + @param [in] signal_frame One frame of the signal, + as extracted using the function ExtractWindow() using the options + returned by this->GetFrameOptions(). The function will use the + vector as a workspace, which is why it's a non-const pointer. 
+ @param [out] feature Pointer to a vector of size this->Dim(), to which + the computed feature will be written. + */ + void Compute(BaseFloat signal_raw_log_energy, + BaseFloat vtln_warp, + VectorBase *signal_frame, + VectorBase *feature); + + ~FbankComputer(); + + private: + const MelBanks *GetMelBanks(BaseFloat vtln_warp); + + + FbankOptions opts_; + BaseFloat log_energy_floor_; + std::map mel_banks_; // BaseFloat is VTLN coefficient. + SplitRadixRealFft *srfft_; + // Disallow assignment. + FbankComputer &operator =(const FbankComputer &other); +}; + +typedef OfflineFeatureTpl Fbank; + +/// @} End of "addtogroup feat" +} // namespace kaldi + + +#endif // KALDI_FEAT_FEATURE_FBANK_H_ diff --git a/speechx/speechx/kaldi/feat/feature-functions.cc b/speechx/speechx/kaldi/feat/feature-functions.cc new file mode 100644 index 00000000..76500ccf --- /dev/null +++ b/speechx/speechx/kaldi/feat/feature-functions.cc @@ -0,0 +1,362 @@ +// feat/feature-functions.cc + +// Copyright 2009-2011 Karel Vesely; Petr Motlicek; Microsoft Corporation +// 2013 Johns Hopkins University (author: Daniel Povey) +// 2014 IMSL, PKU-HKUST (author: Wei Shi) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ + +#include "feat/feature-functions.h" +#include "matrix/matrix-functions.h" + + +namespace kaldi { + +void ComputePowerSpectrum(VectorBase *waveform) { + int32 dim = waveform->Dim(); + + // no, letting it be non-power-of-two for now. + // KALDI_ASSERT(dim > 0 && (dim & (dim-1) == 0)); // make sure a power of two.. actually my FFT code + // does not require this (dan) but this is better in case we use different code [dan]. + + // RealFft(waveform, true); // true == forward (not inverse) FFT; makes no difference here, + // as we just want power spectrum. + + // now we have in waveform, first half of complex spectrum + // it's stored as [real0, realN/2, real1, im1, real2, im2, ...] + int32 half_dim = dim/2; + BaseFloat first_energy = (*waveform)(0) * (*waveform)(0), + last_energy = (*waveform)(1) * (*waveform)(1); // handle this special case + for (int32 i = 1; i < half_dim; i++) { + BaseFloat real = (*waveform)(i*2), im = (*waveform)(i*2 + 1); + (*waveform)(i) = real*real + im*im; + } + (*waveform)(0) = first_energy; + (*waveform)(half_dim) = last_energy; // Will actually never be used, and anyway + // if the signal has been bandlimited sensibly this should be zero. +} + + +DeltaFeatures::DeltaFeatures(const DeltaFeaturesOptions &opts): opts_(opts) { + KALDI_ASSERT(opts.order >= 0 && opts.order < 1000); // just make sure we don't get binary junk. + // opts will normally be 2 or 3. + KALDI_ASSERT(opts.window > 0 && opts.window < 1000); // again, basic sanity check. + // normally the window size will be two. + + scales_.resize(opts.order+1); + scales_[0].Resize(1); + scales_[0](0) = 1.0; // trivial window for 0th order delta [i.e. baseline feats] + + for (int32 i = 1; i <= opts.order; i++) { + Vector &prev_scales = scales_[i-1], + &cur_scales = scales_[i]; + int32 window = opts.window; // this code is designed to still + // work if instead we later make it an array and do opts.window[i-1], + // or something like that. 
"window" is a parameter specifying delta-window + // width which is actually 2*window + 1. + KALDI_ASSERT(window != 0); + int32 prev_offset = (static_cast(prev_scales.Dim()-1))/2, + cur_offset = prev_offset + window; + cur_scales.Resize(prev_scales.Dim() + 2*window); // also zeros it. + + BaseFloat normalizer = 0.0; + for (int32 j = -window; j <= window; j++) { + normalizer += j*j; + for (int32 k = -prev_offset; k <= prev_offset; k++) { + cur_scales(j+k+cur_offset) += + static_cast(j) * prev_scales(k+prev_offset); + } + } + cur_scales.Scale(1.0 / normalizer); + } +} + +void DeltaFeatures::Process(const MatrixBase &input_feats, + int32 frame, + VectorBase *output_frame) const { + KALDI_ASSERT(frame < input_feats.NumRows()); + int32 num_frames = input_feats.NumRows(), + feat_dim = input_feats.NumCols(); + KALDI_ASSERT(static_cast(output_frame->Dim()) == feat_dim * (opts_.order+1)); + output_frame->SetZero(); + for (int32 i = 0; i <= opts_.order; i++) { + const Vector &scales = scales_[i]; + int32 max_offset = (scales.Dim() - 1) / 2; + SubVector output(*output_frame, i*feat_dim, feat_dim); + for (int32 j = -max_offset; j <= max_offset; j++) { + // if asked to read + int32 offset_frame = frame + j; + if (offset_frame < 0) offset_frame = 0; + else if (offset_frame >= num_frames) + offset_frame = num_frames - 1; + BaseFloat scale = scales(j + max_offset); + if (scale != 0.0) + output.AddVec(scale, input_feats.Row(offset_frame)); + } + } +} + +ShiftedDeltaFeatures::ShiftedDeltaFeatures( + const ShiftedDeltaFeaturesOptions &opts): opts_(opts) { + KALDI_ASSERT(opts.window > 0 && opts.window < 1000); + + // Default window is 1. + int32 window = opts.window; + KALDI_ASSERT(window != 0); + scales_.Resize(1 + 2*window); // also zeros it. 
+ BaseFloat normalizer = 0.0; + for (int32 j = -window; j <= window; j++) { + normalizer += j*j; + scales_(j + window) += static_cast(j); + } + scales_.Scale(1.0 / normalizer); +} + +void ShiftedDeltaFeatures::Process(const MatrixBase &input_feats, + int32 frame, + SubVector *output_frame) const { + KALDI_ASSERT(frame < input_feats.NumRows()); + int32 num_frames = input_feats.NumRows(), + feat_dim = input_feats.NumCols(); + KALDI_ASSERT(static_cast(output_frame->Dim()) + == feat_dim * (opts_.num_blocks + 1)); + output_frame->SetZero(); + + // The original features + SubVector output(*output_frame, 0, feat_dim); + output.AddVec(1.0, input_feats.Row(frame)); + + // Concatenate the delta-blocks. Each block is block_shift + // (usually 3) frames apart. + for (int32 i = 0; i < opts_.num_blocks; i++) { + int32 max_offset = (scales_.Dim() - 1) / 2; + SubVector output(*output_frame, (i + 1) * feat_dim, feat_dim); + for (int32 j = -max_offset; j <= max_offset; j++) { + int32 offset_frame = frame + j + i * opts_.block_shift; + if (offset_frame < 0) offset_frame = 0; + else if (offset_frame >= num_frames) + offset_frame = num_frames - 1; + BaseFloat scale = scales_(j + max_offset); + if (scale != 0.0) + output.AddVec(scale, input_feats.Row(offset_frame)); + } + } +} + +void ComputeDeltas(const DeltaFeaturesOptions &delta_opts, + const MatrixBase &input_features, + Matrix *output_features) { + output_features->Resize(input_features.NumRows(), + input_features.NumCols() + *(delta_opts.order + 1)); + DeltaFeatures delta(delta_opts); + for (int32 r = 0; r < static_cast(input_features.NumRows()); r++) { + SubVector row(*output_features, r); + delta.Process(input_features, r, &row); + } +} + +void ComputeShiftedDeltas(const ShiftedDeltaFeaturesOptions &delta_opts, + const MatrixBase &input_features, + Matrix *output_features) { + output_features->Resize(input_features.NumRows(), + input_features.NumCols() + * (delta_opts.num_blocks + 1)); + ShiftedDeltaFeatures delta(delta_opts); + 
+ for (int32 r = 0; r < static_cast(input_features.NumRows()); r++) { + SubVector row(*output_features, r); + delta.Process(input_features, r, &row); + } +} + + +void InitIdftBases(int32 n_bases, int32 dimension, Matrix *mat_out) { + BaseFloat angle = M_PI / static_cast(dimension - 1); + BaseFloat scale = 1.0f / (2.0 * static_cast(dimension - 1)); + mat_out->Resize(n_bases, dimension); + for (int32 i = 0; i < n_bases; i++) { + (*mat_out)(i, 0) = 1.0 * scale; + BaseFloat i_fl = static_cast(i); + for (int32 j = 1; j < dimension - 1; j++) { + BaseFloat j_fl = static_cast(j); + (*mat_out)(i, j) = 2.0 * scale * cos(angle * i_fl * j_fl); + } + + (*mat_out)(i, dimension -1) + = scale * cos(angle * i_fl * static_cast(dimension-1)); + } +} + +void SpliceFrames(const MatrixBase &input_features, + int32 left_context, + int32 right_context, + Matrix *output_features) { + int32 T = input_features.NumRows(), D = input_features.NumCols(); + if (T == 0 || D == 0) + KALDI_ERR << "SpliceFrames: empty input"; + KALDI_ASSERT(left_context >= 0 && right_context >= 0); + int32 N = 1 + left_context + right_context; + output_features->Resize(T, D*N); + for (int32 t = 0; t < T; t++) { + SubVector dst_row(*output_features, t); + for (int32 j = 0; j < N; j++) { + int32 t2 = t + j - left_context; + if (t2 < 0) t2 = 0; + if (t2 >= T) t2 = T-1; + SubVector dst(dst_row, j*D, D), + src(input_features, t2); + dst.CopyFromVec(src); + } + } +} + +void ReverseFrames(const MatrixBase &input_features, + Matrix *output_features) { + int32 T = input_features.NumRows(), D = input_features.NumCols(); + if (T == 0 || D == 0) + KALDI_ERR << "ReverseFrames: empty input"; + output_features->Resize(T, D); + for (int32 t = 0; t < T; t++) { + SubVector dst_row(*output_features, t); + SubVector src_row(input_features, T-1-t); + dst_row.CopyFromVec(src_row); + } +} + + +void SlidingWindowCmnOptions::Check() const { + KALDI_ASSERT(cmn_window > 0); + if (center) + KALDI_ASSERT(min_window > 0 && min_window <= 
cmn_window); + // else ignored so value doesn't matter. +} + +// Internal version of SlidingWindowCmn with double-precision arguments. +void SlidingWindowCmnInternal(const SlidingWindowCmnOptions &opts, + const MatrixBase &input, + MatrixBase *output) { + opts.Check(); + int32 num_frames = input.NumRows(), dim = input.NumCols(), + last_window_start = -1, last_window_end = -1, + warning_count = 0; + Vector cur_sum(dim), cur_sumsq(dim); + + for (int32 t = 0; t < num_frames; t++) { + int32 window_start, window_end; // note: window_end will be one + // past the end of the window we use for normalization. + if (opts.center) { + window_start = t - (opts.cmn_window / 2); + window_end = window_start + opts.cmn_window; + } else { + window_start = t - opts.cmn_window; + window_end = t + 1; + } + if (window_start < 0) { // shift window right if starts <0. + window_end -= window_start; + window_start = 0; // or: window_start -= window_start + } + if (!opts.center) { + if (window_end > t) + window_end = std::max(t + 1, opts.min_window); + } + if (window_end > num_frames) { + window_start -= (window_end - num_frames); + window_end = num_frames; + if (window_start < 0) window_start = 0; + } + if (last_window_start == -1) { + SubMatrix input_part(input, + window_start, window_end - window_start, + 0, dim); + cur_sum.AddRowSumMat(1.0, input_part , 0.0); + if (opts.normalize_variance) + cur_sumsq.AddDiagMat2(1.0, input_part, kTrans, 0.0); + } else { + if (window_start > last_window_start) { + KALDI_ASSERT(window_start == last_window_start + 1); + SubVector frame_to_remove(input, last_window_start); + cur_sum.AddVec(-1.0, frame_to_remove); + if (opts.normalize_variance) + cur_sumsq.AddVec2(-1.0, frame_to_remove); + } + if (window_end > last_window_end) { + KALDI_ASSERT(window_end == last_window_end + 1); + SubVector frame_to_add(input, last_window_end); + cur_sum.AddVec(1.0, frame_to_add); + if (opts.normalize_variance) + cur_sumsq.AddVec2(1.0, frame_to_add); + } + } + int32 
window_frames = window_end - window_start; + last_window_start = window_start; + last_window_end = window_end; + + KALDI_ASSERT(window_frames > 0); + SubVector input_frame(input, t), + output_frame(*output, t); + output_frame.CopyFromVec(input_frame); + output_frame.AddVec(-1.0 / window_frames, cur_sum); + + if (opts.normalize_variance) { + if (window_frames == 1) { + output_frame.Set(0.0); + } else { + Vector variance(cur_sumsq); + variance.Scale(1.0 / window_frames); + variance.AddVec2(-1.0 / (window_frames * window_frames), cur_sum); + // now "variance" is the variance of the features in the window, + // around their own mean. + int32 num_floored; + variance.ApplyFloor(1.0e-10, &num_floored); + if (num_floored > 0 && num_frames > 1) { + if (opts.max_warnings == warning_count) { + KALDI_WARN << "Suppressing the remaining variance flooring " + << "warnings. Run program with --max-warnings=-1 to " + << "see all warnings."; + } + // If opts.max_warnings is a negative number, we won't restrict the + // number of times that the warning is printed out. + else if (opts.max_warnings < 0 + || opts.max_warnings > warning_count) { + KALDI_WARN << "Flooring when normalizing variance, floored " + << num_floored << " elements; num-frames was " + << window_frames; + } + warning_count++; + } + variance.ApplyPow(-0.5); // get inverse standard deviation. 
+ output_frame.MulElements(variance); + } + } + } +} + + +void SlidingWindowCmn(const SlidingWindowCmnOptions &opts, + const MatrixBase &input, + MatrixBase *output) { + KALDI_ASSERT(SameDim(input, *output) && input.NumRows() > 0); + Matrix input_dbl(input), output_dbl(input.NumRows(), input.NumCols()); + // call double-precision version + SlidingWindowCmnInternal(opts, input_dbl, &output_dbl); + output->CopyFromMat(output_dbl); +} + + + +} // namespace kaldi diff --git a/speechx/speechx/kaldi/feat/feature-functions.h b/speechx/speechx/kaldi/feat/feature-functions.h new file mode 100644 index 00000000..52454f30 --- /dev/null +++ b/speechx/speechx/kaldi/feat/feature-functions.h @@ -0,0 +1,204 @@ +// feat/feature-functions.h + +// Copyright 2009-2011 Karel Vesely; Petr Motlicek; Microsoft Corporation +// 2014 IMSL, PKU-HKUST (author: Wei Shi) +// 2016 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ + +#ifndef KALDI_FEAT_FEATURE_FUNCTIONS_H_ +#define KALDI_FEAT_FEATURE_FUNCTIONS_H_ + +#include +#include + +#include "matrix/matrix-lib.h" +#include "util/common-utils.h" +#include "base/kaldi-error.h" + +namespace kaldi { +/// @addtogroup feat FeatureExtraction +/// @{ + + +// ComputePowerSpectrum converts a complex FFT (as produced by the FFT +// functions in matrix/matrix-functions.h), and converts it into +// a power spectrum. If the complex FFT is a vector of size n (representing +// half the complex FFT of a real signal of size n, as described there), +// this function computes in the first (n/2) + 1 elements of it, the +// energies of the fft bins from zero to the Nyquist frequency. Contents of the +// remaining (n/2) - 1 elements are undefined at output. +void ComputePowerSpectrum(VectorBase *complex_fft); + + +struct DeltaFeaturesOptions { + int32 order; + int32 window; // e.g. 2; controls window size (window size is 2*window + 1) + // the behavior at the edges is to replicate the first or last frame. + // this is not configurable. + + DeltaFeaturesOptions(int32 order = 2, int32 window = 2): + order(order), window(window) { } + void Register(OptionsItf *opts) { + opts->Register("delta-order", &order, "Order of delta computation"); + opts->Register("delta-window", &window, + "Parameter controlling window for delta computation (actual window" + " size for each delta order is 1 + 2*delta-window-size)"); + } +}; + +class DeltaFeatures { + public: + // This class provides a low-level function to compute delta features. + // The function takes as input a matrix of features and a frame index + // that it should compute the deltas on. It puts its output in an object + // of type VectorBase, of size (original-feature-dimension) * (opts.order+1). 
+ // This is not the most efficient way to do the computation, but it's + // state-free and thus easier to understand + + explicit DeltaFeatures(const DeltaFeaturesOptions &opts); + + void Process(const MatrixBase &input_feats, + int32 frame, + VectorBase *output_frame) const; + private: + DeltaFeaturesOptions opts_; + std::vector > scales_; // a scaling window for each + // of the orders, including zero: multiply the features for each + // dimension by this window. +}; + +struct ShiftedDeltaFeaturesOptions { + int32 window, // The time delay and advance + num_blocks, + block_shift; // Distance between consecutive blocks + + ShiftedDeltaFeaturesOptions(): + window(1), num_blocks(7), block_shift(3) { } + void Register(OptionsItf *opts) { + opts->Register("delta-window", &window, "Size of delta advance and delay."); + opts->Register("num-blocks", &num_blocks, "Number of delta blocks in advance" + " of each frame to be concatenated"); + opts->Register("block-shift", &block_shift, "Distance between each block"); + } +}; + +class ShiftedDeltaFeatures { + public: + // This class provides a low-level function to compute shifted + // delta cesptra (SDC). + // The function takes as input a matrix of features and a frame index + // that it should compute the deltas on. It puts its output in an object + // of type VectorBase, of size original-feature-dimension + (1 * num_blocks). + + explicit ShiftedDeltaFeatures(const ShiftedDeltaFeaturesOptions &opts); + + void Process(const MatrixBase &input_feats, + int32 frame, + SubVector *output_frame) const; + private: + ShiftedDeltaFeaturesOptions opts_; + Vector scales_; // a scaling window for each + +}; + +// ComputeDeltas is a convenience function that computes deltas on a feature +// file. If you want to deal with features coming in bit by bit you would have +// to use the DeltaFeatures class directly, and do the computation frame by +// frame. 
Later we will have to come up with a nice mechanism to do this for +// features coming in. +void ComputeDeltas(const DeltaFeaturesOptions &delta_opts, + const MatrixBase &input_features, + Matrix *output_features); + +// ComputeShiftedDeltas computes deltas from a feature file by applying +// ShiftedDeltaFeatures over the frames. This function is provided for +// convenience, however, ShiftedDeltaFeatures can be used directly. +void ComputeShiftedDeltas(const ShiftedDeltaFeaturesOptions &delta_opts, + const MatrixBase &input_features, + Matrix *output_features); + +// SpliceFrames will normally be used together with LDA. +// It splices frames together to make a window. At the +// start and end of an utterance, it duplicates the first +// and last frames. +// Will throw if input features are empty. +// left_context and right_context must be nonnegative. +// these both represent a number of frames (e.g. 4, 4 is +// a good choice). +void SpliceFrames(const MatrixBase &input_features, + int32 left_context, + int32 right_context, + Matrix *output_features); + +// ReverseFrames reverses the frames in time (used for backwards decoding) +void ReverseFrames(const MatrixBase &input_features, + Matrix *output_features); + + +void InitIdftBases(int32 n_bases, int32 dimension, Matrix *mat_out); + + +// This is used for speaker-id. Also see OnlineCmnOptions in ../online2/, which +// is online CMN with no latency, for online speech recognition. 
+struct SlidingWindowCmnOptions { + int32 cmn_window; + int32 min_window; + int32 max_warnings; + bool normalize_variance; + bool center; + + SlidingWindowCmnOptions(): + cmn_window(600), + min_window(100), + max_warnings(5), + normalize_variance(false), + center(false) { } + + void Register(OptionsItf *opts) { + opts->Register("cmn-window", &cmn_window, "Window in frames for running " + "average CMN computation"); + opts->Register("min-cmn-window", &min_window, "Minimum CMN window " + "used at start of decoding (adds latency only at start). " + "Only applicable if center == false, ignored if center==true"); + opts->Register("max-warnings", &max_warnings, "Maximum warnings to report " + "per utterance. 0 to disable, -1 to show all."); + opts->Register("norm-vars", &normalize_variance, "If true, normalize " + "variance to one."); // naming this as in apply-cmvn.cc + opts->Register("center", ¢er, "If true, use a window centered on the " + "current frame (to the extent possible, modulo end effects). " + "If false, window is to the left."); + } + void Check() const; +}; + + +/// Applies sliding-window cepstral mean and/or variance normalization. See the +/// strings registering the options in the options class for information on how +/// this works and what the options are. input and output must have the same +/// dimension. 
+void SlidingWindowCmn(const SlidingWindowCmnOptions &opts, + const MatrixBase &input, + MatrixBase *output); + + +/// @} End of "addtogroup feat" +} // namespace kaldi + + + +#endif // KALDI_FEAT_FEATURE_FUNCTIONS_H_ diff --git a/speechx/speechx/kaldi/feat/feature-mfcc.cc b/speechx/speechx/kaldi/feat/feature-mfcc.cc new file mode 100644 index 00000000..73ab4b31 --- /dev/null +++ b/speechx/speechx/kaldi/feat/feature-mfcc.cc @@ -0,0 +1,157 @@ +// feat/feature-mfcc.cc + +// Copyright 2009-2011 Karel Vesely; Petr Motlicek +// 2016 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "feat/feature-mfcc.h" + + +namespace kaldi { + + +void MfccComputer::Compute(BaseFloat signal_raw_log_energy, + BaseFloat vtln_warp, + VectorBase *signal_frame, + VectorBase *feature) { + KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.PaddedWindowSize() && + feature->Dim() == this->Dim()); + + const MelBanks &mel_banks = *(GetMelBanks(vtln_warp)); + + if (opts_.use_energy && !opts_.raw_energy) + signal_raw_log_energy = Log(std::max(VecVec(*signal_frame, *signal_frame), + std::numeric_limits::epsilon())); + + if (srfft_ != NULL) // Compute FFT using the split-radix algorithm. + srfft_->Compute(signal_frame->Data(), true); + else // An alternative algorithm that works for non-powers-of-two. 
+ RealFft(signal_frame, true); + + // Convert the FFT into a power spectrum. + ComputePowerSpectrum(signal_frame); + SubVector power_spectrum(*signal_frame, 0, + signal_frame->Dim() / 2 + 1); + + mel_banks.Compute(power_spectrum, &mel_energies_); + + // avoid log of zero (which should be prevented anyway by dithering). + mel_energies_.ApplyFloor(std::numeric_limits::epsilon()); + mel_energies_.ApplyLog(); // take the log. + + feature->SetZero(); // in case there were NaNs. + // feature = dct_matrix_ * mel_energies [which now have log] + feature->AddMatVec(1.0, dct_matrix_, kNoTrans, mel_energies_, 0.0); + + if (opts_.cepstral_lifter != 0.0) + feature->MulElements(lifter_coeffs_); + + if (opts_.use_energy) { + if (opts_.energy_floor > 0.0 && signal_raw_log_energy < log_energy_floor_) + signal_raw_log_energy = log_energy_floor_; + (*feature)(0) = signal_raw_log_energy; + } + + if (opts_.htk_compat) { + BaseFloat energy = (*feature)(0); + for (int32 i = 0; i < opts_.num_ceps - 1; i++) + (*feature)(i) = (*feature)(i+1); + if (!opts_.use_energy) + energy *= M_SQRT2; // scale on C0 (actually removing a scale + // we previously added that's part of one common definition of + // the cosine transform.) + (*feature)(opts_.num_ceps - 1) = energy; + } +} + +MfccComputer::MfccComputer(const MfccOptions &opts): + opts_(opts), srfft_(NULL), + mel_energies_(opts.mel_opts.num_bins) { + + int32 num_bins = opts.mel_opts.num_bins; + if (opts.num_ceps > num_bins) + KALDI_ERR << "num-ceps cannot be larger than num-mel-bins." + << " It should be smaller or equal. You provided num-ceps: " + << opts.num_ceps << " and num-mel-bins: " + << num_bins; + + Matrix dct_matrix(num_bins, num_bins); + ComputeDctMatrix(&dct_matrix); + // Note that we include zeroth dct in either case. If using the + // energy we replace this with the energy. This means a different + // ordering of features than HTK. 
+ SubMatrix dct_rows(dct_matrix, 0, opts.num_ceps, 0, num_bins); + dct_matrix_.Resize(opts.num_ceps, num_bins); + dct_matrix_.CopyFromMat(dct_rows); // subset of rows. + if (opts.cepstral_lifter != 0.0) { + lifter_coeffs_.Resize(opts.num_ceps); + ComputeLifterCoeffs(opts.cepstral_lifter, &lifter_coeffs_); + } + if (opts.energy_floor > 0.0) + log_energy_floor_ = Log(opts.energy_floor); + + int32 padded_window_size = opts.frame_opts.PaddedWindowSize(); + if ((padded_window_size & (padded_window_size-1)) == 0) // Is a power of two... + srfft_ = new SplitRadixRealFft(padded_window_size); + + // We'll definitely need the filterbanks info for VTLN warping factor 1.0. + // [note: this call caches it.] + GetMelBanks(1.0); +} + +MfccComputer::MfccComputer(const MfccComputer &other): + opts_(other.opts_), lifter_coeffs_(other.lifter_coeffs_), + dct_matrix_(other.dct_matrix_), + log_energy_floor_(other.log_energy_floor_), + mel_banks_(other.mel_banks_), + srfft_(NULL), + mel_energies_(other.mel_energies_.Dim(), kUndefined) { + for (std::map::iterator iter = mel_banks_.begin(); + iter != mel_banks_.end(); ++iter) + iter->second = new MelBanks(*(iter->second)); + if (other.srfft_ != NULL) + srfft_ = new SplitRadixRealFft(*(other.srfft_)); +} + + + +MfccComputer::~MfccComputer() { + for (std::map::iterator iter = mel_banks_.begin(); + iter != mel_banks_.end(); + ++iter) + delete iter->second; + delete srfft_; +} + +const MelBanks *MfccComputer::GetMelBanks(BaseFloat vtln_warp) { + MelBanks *this_mel_banks = NULL; + std::map::iterator iter = mel_banks_.find(vtln_warp); + if (iter == mel_banks_.end()) { + this_mel_banks = new MelBanks(opts_.mel_opts, + opts_.frame_opts, + vtln_warp); + mel_banks_[vtln_warp] = this_mel_banks; + } else { + this_mel_banks = iter->second; + } + return this_mel_banks; +} + + + +} // namespace kaldi diff --git a/speechx/speechx/kaldi/feat/feature-mfcc.h b/speechx/speechx/kaldi/feat/feature-mfcc.h new file mode 100644 index 00000000..dbfb9d60 --- 
/dev/null +++ b/speechx/speechx/kaldi/feat/feature-mfcc.h @@ -0,0 +1,154 @@ +// feat/feature-mfcc.h + +// Copyright 2009-2011 Karel Vesely; Petr Motlicek; Saarland University +// 2014-2016 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_FEAT_FEATURE_MFCC_H_ +#define KALDI_FEAT_FEATURE_MFCC_H_ + +#include +#include + +#include "feat/feature-common.h" +#include "feat/feature-functions.h" +#include "feat/feature-window.h" +#include "feat/mel-computations.h" + +namespace kaldi { +/// @addtogroup feat FeatureExtraction +/// @{ + + +/// MfccOptions contains basic options for computing MFCC features. +struct MfccOptions { + FrameExtractionOptions frame_opts; + MelBanksOptions mel_opts; + int32 num_ceps; // e.g. 13: num cepstral coeffs, counting zero. + bool use_energy; // use energy; else C0 + BaseFloat energy_floor; // 0 by default; set to a value like 1.0 or 0.1 if + // you disable dithering. + bool raw_energy; // If true, compute energy before preemphasis and windowing + BaseFloat cepstral_lifter; // Scaling factor on cepstra for HTK compatibility. + // if 0.0, no liftering is done. + bool htk_compat; // if true, put energy/C0 last and introduce a factor of + // sqrt(2) on C0 to be the same as HTK. 
+ + MfccOptions() : mel_opts(23), + // defaults the #mel-banks to 23 for the MFCC computations. + // this seems to be common for 16khz-sampled data, + // but for 8khz-sampled data, 15 may be better. + num_ceps(13), + use_energy(true), + energy_floor(0.0), + raw_energy(true), + cepstral_lifter(22.0), + htk_compat(false) {} + + void Register(OptionsItf *opts) { + frame_opts.Register(opts); + mel_opts.Register(opts); + opts->Register("num-ceps", &num_ceps, + "Number of cepstra in MFCC computation (including C0)"); + opts->Register("use-energy", &use_energy, + "Use energy (not C0) in MFCC computation"); + opts->Register("energy-floor", &energy_floor, + "Floor on energy (absolute, not relative) in MFCC computation. " + "Only makes a difference if --use-energy=true; only necessary if " + "--dither=0.0. Suggested values: 0.1 or 1.0"); + opts->Register("raw-energy", &raw_energy, + "If true, compute energy before preemphasis and windowing"); + opts->Register("cepstral-lifter", &cepstral_lifter, + "Constant that controls scaling of MFCCs"); + opts->Register("htk-compat", &htk_compat, + "If true, put energy or C0 last and use a factor of sqrt(2) on " + "C0. Warning: not sufficient to get HTK compatible features " + "(need to change other parameters)."); + } +}; + + + +// This is the new-style interface to the MFCC computation. +class MfccComputer { + public: + typedef MfccOptions Options; + explicit MfccComputer(const MfccOptions &opts); + MfccComputer(const MfccComputer &other); + + const FrameExtractionOptions &GetFrameOptions() const { + return opts_.frame_opts; + } + + int32 Dim() const { return opts_.num_ceps; } + + bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; } + + /** + Function that computes one frame of features from + one frame of signal. + + @param [in] signal_raw_log_energy The log-energy of the frame of the signal + prior to windowing and pre-emphasis, or + log(numeric_limits::min()), whichever is greater. 
Must be + ignored by this function if this class returns false from + this->NeedsRawLogEnergy(). + @param [in] vtln_warp The VTLN warping factor that the user wants + to be applied when computing features for this utterance. Will + normally be 1.0, meaning no warping is to be done. The value will + be ignored for feature types that don't support VLTN, such as + spectrogram features. + @param [in] signal_frame One frame of the signal, + as extracted using the function ExtractWindow() using the options + returned by this->GetFrameOptions(). The function will use the + vector as a workspace, which is why it's a non-const pointer. + @param [out] feature Pointer to a vector of size this->Dim(), to which + the computed feature will be written. + */ + void Compute(BaseFloat signal_raw_log_energy, + BaseFloat vtln_warp, + VectorBase *signal_frame, + VectorBase *feature); + + ~MfccComputer(); + private: + // disallow assignment. + MfccComputer &operator = (const MfccComputer &in); + + protected: + const MelBanks *GetMelBanks(BaseFloat vtln_warp); + + MfccOptions opts_; + Vector lifter_coeffs_; + Matrix dct_matrix_; // matrix we left-multiply by to perform DCT. + BaseFloat log_energy_floor_; + std::map mel_banks_; // BaseFloat is VTLN coefficient. + SplitRadixRealFft *srfft_; + + // note: mel_energies_ is specific to the frame we're processing, it's + // just a temporary workspace. 
+ Vector mel_energies_; +}; + +typedef OfflineFeatureTpl Mfcc; + + +/// @} End of "addtogroup feat" +} // namespace kaldi + + +#endif // KALDI_FEAT_FEATURE_MFCC_H_ diff --git a/speechx/speechx/kaldi/feat/feature-plp.cc b/speechx/speechx/kaldi/feat/feature-plp.cc new file mode 100644 index 00000000..e0c270c7 --- /dev/null +++ b/speechx/speechx/kaldi/feat/feature-plp.cc @@ -0,0 +1,191 @@ +// feat/feature-plp.cc + +// Copyright 2009-2011 Petr Motlicek; Karel Vesely +// 2016 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ + +#include "feat/feature-plp.h" + +namespace kaldi { + +PlpComputer::PlpComputer(const PlpOptions &opts): + opts_(opts), srfft_(NULL), + mel_energies_duplicated_(opts_.mel_opts.num_bins + 2, kUndefined), + autocorr_coeffs_(opts_.lpc_order + 1, kUndefined), + lpc_coeffs_(opts_.lpc_order, kUndefined), + raw_cepstrum_(opts_.lpc_order, kUndefined) { + + if (opts.cepstral_lifter != 0.0) { + lifter_coeffs_.Resize(opts.num_ceps); + ComputeLifterCoeffs(opts.cepstral_lifter, &lifter_coeffs_); + } + InitIdftBases(opts_.lpc_order + 1, opts_.mel_opts.num_bins + 2, + &idft_bases_); + + if (opts.energy_floor > 0.0) + log_energy_floor_ = Log(opts.energy_floor); + + int32 padded_window_size = opts.frame_opts.PaddedWindowSize(); + if ((padded_window_size & (padded_window_size-1)) == 0) // Is a power of two... + srfft_ = new SplitRadixRealFft(padded_window_size); + + // We'll definitely need the filterbanks info for VTLN warping factor 1.0. + // [note: this call caches it.] + GetMelBanks(1.0); +} + +PlpComputer::PlpComputer(const PlpComputer &other): + opts_(other.opts_), lifter_coeffs_(other.lifter_coeffs_), + idft_bases_(other.idft_bases_), log_energy_floor_(other.log_energy_floor_), + mel_banks_(other.mel_banks_), equal_loudness_(other.equal_loudness_), + srfft_(NULL), + mel_energies_duplicated_(opts_.mel_opts.num_bins + 2, kUndefined), + autocorr_coeffs_(opts_.lpc_order + 1, kUndefined), + lpc_coeffs_(opts_.lpc_order, kUndefined), + raw_cepstrum_(opts_.lpc_order, kUndefined) { + for (std::map::iterator iter = mel_banks_.begin(); + iter != mel_banks_.end(); ++iter) + iter->second = new MelBanks(*(iter->second)); + for (std::map*>::iterator + iter = equal_loudness_.begin(); + iter != equal_loudness_.end(); ++iter) + iter->second = new Vector(*(iter->second)); + if (other.srfft_ != NULL) + srfft_ = new SplitRadixRealFft(*(other.srfft_)); +} + +PlpComputer::~PlpComputer() { + for (std::map::iterator iter = mel_banks_.begin(); + iter != mel_banks_.end(); ++iter) + delete 
iter->second; + for (std::map* >::iterator + iter = equal_loudness_.begin(); + iter != equal_loudness_.end(); ++iter) + delete iter->second; + delete srfft_; +} + +const MelBanks *PlpComputer::GetMelBanks(BaseFloat vtln_warp) { + MelBanks *this_mel_banks = NULL; + std::map::iterator iter = mel_banks_.find(vtln_warp); + if (iter == mel_banks_.end()) { + this_mel_banks = new MelBanks(opts_.mel_opts, + opts_.frame_opts, + vtln_warp); + mel_banks_[vtln_warp] = this_mel_banks; + } else { + this_mel_banks = iter->second; + } + return this_mel_banks; +} + +const Vector *PlpComputer::GetEqualLoudness(BaseFloat vtln_warp) { + const MelBanks *this_mel_banks = GetMelBanks(vtln_warp); + Vector *ans = NULL; + std::map*>::iterator iter + = equal_loudness_.find(vtln_warp); + if (iter == equal_loudness_.end()) { + ans = new Vector; + GetEqualLoudnessVector(*this_mel_banks, ans); + equal_loudness_[vtln_warp] = ans; + } else { + ans = iter->second; + } + return ans; +} + +void PlpComputer::Compute(BaseFloat signal_raw_log_energy, + BaseFloat vtln_warp, + VectorBase *signal_frame, + VectorBase *feature) { + KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.PaddedWindowSize() && + feature->Dim() == this->Dim()); + + const MelBanks &mel_banks = *GetMelBanks(vtln_warp); + const Vector &equal_loudness = *GetEqualLoudness(vtln_warp); + + + KALDI_ASSERT(opts_.num_ceps <= opts_.lpc_order+1); // our num-ceps includes C0. + + + if (opts_.use_energy && !opts_.raw_energy) + signal_raw_log_energy = Log(std::max(VecVec(*signal_frame, *signal_frame), + std::numeric_limits::min())); + + if (srfft_ != NULL) // Compute FFT using split-radix algorithm. + srfft_->Compute(signal_frame->Data(), true); + else // An alternative algorithm that works for non-powers-of-two. + RealFft(signal_frame, true); + + // Convert the FFT into a power spectrum. + ComputePowerSpectrum(signal_frame); // elements 0 ... 
signal_frame->Dim()/2 + + SubVector power_spectrum(*signal_frame, + 0, signal_frame->Dim() / 2 + 1); + + int32 num_mel_bins = opts_.mel_opts.num_bins; + + SubVector mel_energies(mel_energies_duplicated_, 1, num_mel_bins); + + mel_banks.Compute(power_spectrum, &mel_energies); + + mel_energies.MulElements(equal_loudness); + + mel_energies.ApplyPow(opts_.compress_factor); + + // duplicate first and last elements + mel_energies_duplicated_(0) = mel_energies_duplicated_(1); + mel_energies_duplicated_(num_mel_bins + 1) = + mel_energies_duplicated_(num_mel_bins); + + autocorr_coeffs_.SetZero(); // In case of NaNs or infs + autocorr_coeffs_.AddMatVec(1.0, idft_bases_, kNoTrans, + mel_energies_duplicated_, 0.0); + + BaseFloat residual_log_energy = ComputeLpc(autocorr_coeffs_, &lpc_coeffs_); + + residual_log_energy = std::max(residual_log_energy, + std::numeric_limits::min()); + + Lpc2Cepstrum(opts_.lpc_order, lpc_coeffs_.Data(), raw_cepstrum_.Data()); + feature->Range(1, opts_.num_ceps - 1).CopyFromVec( + raw_cepstrum_.Range(0, opts_.num_ceps - 1)); + (*feature)(0) = residual_log_energy; + + if (opts_.cepstral_lifter != 0.0) + feature->MulElements(lifter_coeffs_); + + if (opts_.cepstral_scale != 1.0) + feature->Scale(opts_.cepstral_scale); + + if (opts_.use_energy) { + if (opts_.energy_floor > 0.0 && signal_raw_log_energy < log_energy_floor_) + signal_raw_log_energy = log_energy_floor_; + (*feature)(0) = signal_raw_log_energy; + } + + if (opts_.htk_compat) { // reorder the features. 
+ BaseFloat log_energy = (*feature)(0); + for (int32 i = 0; i < opts_.num_ceps-1; i++) + (*feature)(i) = (*feature)(i+1); + (*feature)(opts_.num_ceps-1) = log_energy; + } +} + + +} // namespace kaldi diff --git a/speechx/speechx/kaldi/feat/feature-plp.h b/speechx/speechx/kaldi/feat/feature-plp.h new file mode 100644 index 00000000..4f156ca1 --- /dev/null +++ b/speechx/speechx/kaldi/feat/feature-plp.h @@ -0,0 +1,176 @@ +// feat/feature-plp.h + +// Copyright 2009-2011 Petr Motlicek; Karel Vesely + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_FEAT_FEATURE_PLP_H_ +#define KALDI_FEAT_FEATURE_PLP_H_ + +#include +#include + +#include "feat/feature-common.h" +#include "feat/feature-functions.h" +#include "feat/feature-window.h" +#include "feat/mel-computations.h" +#include "itf/options-itf.h" + +namespace kaldi { +/// @addtogroup feat FeatureExtraction +/// @{ + + + +/// PlpOptions contains basic options for computing PLP features. +/// It only includes things that can be done in a "stateless" way, i.e. +/// it does not include energy max-normalization. +/// It does not include delta computation. 
+struct PlpOptions { + FrameExtractionOptions frame_opts; + MelBanksOptions mel_opts; + int32 lpc_order; + int32 num_ceps; // num cepstra including zero + bool use_energy; // use energy; else C0 + BaseFloat energy_floor; + bool raw_energy; // If true, compute energy before preemphasis and windowing + BaseFloat compress_factor; + int32 cepstral_lifter; + BaseFloat cepstral_scale; + + bool htk_compat; // if true, put energy/C0 last and introduce a factor of + // sqrt(2) on C0 to be the same as HTK. + + PlpOptions() : mel_opts(23), + // default number of mel-banks for the PLP computation; this + // seems to be common for 16kHz-sampled data. For 8kHz-sampled + // data, 15 may be better. + lpc_order(12), + num_ceps(13), + use_energy(true), + energy_floor(0.0), + raw_energy(true), + compress_factor(0.33333), + cepstral_lifter(22), + cepstral_scale(1.0), + htk_compat(false) {} + + void Register(OptionsItf *opts) { + frame_opts.Register(opts); + mel_opts.Register(opts); + opts->Register("lpc-order", &lpc_order, + "Order of LPC analysis in PLP computation"); + opts->Register("num-ceps", &num_ceps, + "Number of cepstra in PLP computation (including C0)"); + opts->Register("use-energy", &use_energy, + "Use energy (not C0) for zeroth PLP feature"); + opts->Register("energy-floor", &energy_floor, + "Floor on energy (absolute, not relative) in PLP computation. " + "Only makes a difference if --use-energy=true; only necessary if " + "--dither=0.0. Suggested values: 0.1 or 1.0"); + opts->Register("raw-energy", &raw_energy, + "If true, compute energy before preemphasis and windowing"); + opts->Register("compress-factor", &compress_factor, + "Compression factor in PLP computation"); + opts->Register("cepstral-lifter", &cepstral_lifter, + "Constant that controls scaling of PLPs"); + opts->Register("cepstral-scale", &cepstral_scale, + "Scaling constant in PLP computation"); + opts->Register("htk-compat", &htk_compat, + "If true, put energy or C0 last. 
Warning: not sufficient " + "to get HTK compatible features (need to change other " + "parameters)."); + } +}; + + +/// This is the new-style interface to the PLP computation. +class PlpComputer { + public: + typedef PlpOptions Options; + explicit PlpComputer(const PlpOptions &opts); + PlpComputer(const PlpComputer &other); + + const FrameExtractionOptions &GetFrameOptions() const { + return opts_.frame_opts; + } + + int32 Dim() const { return opts_.num_ceps; } + + bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; } + + /** + Function that computes one frame of features from + one frame of signal. + + @param [in] signal_raw_log_energy The log-energy of the frame of the signal + prior to windowing and pre-emphasis, or + log(numeric_limits::min()), whichever is greater. Must be + ignored by this function if this class returns false from + this->NeedsRawLogEnergy(). + @param [in] vtln_warp The VTLN warping factor that the user wants + to be applied when computing features for this utterance. Will + normally be 1.0, meaning no warping is to be done. The value will + be ignored for feature types that don't support VLTN, such as + spectrogram features. + @param [in] signal_frame One frame of the signal, + as extracted using the function ExtractWindow() using the options + returned by this->GetFrameOptions(). The function will use the + vector as a workspace, which is why it's a non-const pointer. + @param [out] feature Pointer to a vector of size this->Dim(), to which + the computed feature will be written. + */ + void Compute(BaseFloat signal_raw_log_energy, + BaseFloat vtln_warp, + VectorBase *signal_frame, + VectorBase *feature); + + ~PlpComputer(); + private: + + const MelBanks *GetMelBanks(BaseFloat vtln_warp); + + const Vector *GetEqualLoudness(BaseFloat vtln_warp); + + PlpOptions opts_; + Vector lifter_coeffs_; + Matrix idft_bases_; + BaseFloat log_energy_floor_; + std::map mel_banks_; // BaseFloat is VTLN coefficient. 
+ std::map* > equal_loudness_; + SplitRadixRealFft *srfft_; + + // temporary vector used inside Compute; size is opts_.mel_opts.num_bins + 2 + Vector mel_energies_duplicated_; + // temporary vector used inside Compute; size is opts_.lpc_order + 1 + Vector autocorr_coeffs_; + // temporary vector used inside Compute; size is opts_.lpc_order + Vector lpc_coeffs_; + // temporary vector used inside Compute; size is opts_.lpc_order + Vector raw_cepstrum_; + + // Disallow assignment. + PlpComputer &operator =(const PlpComputer &other); +}; + +typedef OfflineFeatureTpl Plp; + +/// @} End of "addtogroup feat" + +} // namespace kaldi + + +#endif // KALDI_FEAT_FEATURE_PLP_H_ diff --git a/speechx/speechx/kaldi/feat/feature-spectrogram.cc b/speechx/speechx/kaldi/feat/feature-spectrogram.cc new file mode 100644 index 00000000..7eee2643 --- /dev/null +++ b/speechx/speechx/kaldi/feat/feature-spectrogram.cc @@ -0,0 +1,82 @@ +// feat/feature-spectrogram.cc + +// Copyright 2009-2012 Karel Vesely +// Copyright 2012 Navdeep Jaitly + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ + +#include "feat/feature-spectrogram.h" + + +namespace kaldi { + +SpectrogramComputer::SpectrogramComputer(const SpectrogramOptions &opts) + : opts_(opts), srfft_(NULL) { + if (opts.energy_floor > 0.0) + log_energy_floor_ = Log(opts.energy_floor); + + int32 padded_window_size = opts.frame_opts.PaddedWindowSize(); + if ((padded_window_size & (padded_window_size-1)) == 0) // Is a power of two + srfft_ = new SplitRadixRealFft(padded_window_size); +} + +SpectrogramComputer::SpectrogramComputer(const SpectrogramComputer &other): + opts_(other.opts_), log_energy_floor_(other.log_energy_floor_), srfft_(NULL) { + if (other.srfft_ != NULL) + srfft_ = new SplitRadixRealFft(*other.srfft_); +} + +SpectrogramComputer::~SpectrogramComputer() { + delete srfft_; +} + +void SpectrogramComputer::Compute(BaseFloat signal_raw_log_energy, + BaseFloat vtln_warp, + VectorBase *signal_frame, + VectorBase *feature) { + KALDI_ASSERT(signal_frame->Dim() == opts_.frame_opts.PaddedWindowSize() && + feature->Dim() == this->Dim()); + + + // Compute energy after window function (not the raw one) + if (!opts_.raw_energy) + signal_raw_log_energy = Log(std::max(VecVec(*signal_frame, *signal_frame), + std::numeric_limits::epsilon())); + + if (srfft_ != NULL) // Compute FFT using split-radix algorithm. + srfft_->Compute(signal_frame->Data(), true); + else // An alternative algorithm that works for non-powers-of-two + RealFft(signal_frame, true); + + // Convert the FFT into a power spectrum. + ComputePowerSpectrum(signal_frame); + SubVector power_spectrum(*signal_frame, + 0, signal_frame->Dim() / 2 + 1); + + power_spectrum.ApplyFloor(std::numeric_limits::epsilon()); + power_spectrum.ApplyLog(); + + feature->CopyFromVec(power_spectrum); + + if (opts_.energy_floor > 0.0 && signal_raw_log_energy < log_energy_floor_) + signal_raw_log_energy = log_energy_floor_; + // The zeroth spectrogram component is always set to the signal energy, + // instead of the square of the constant component of the signal. 
+ (*feature)(0) = signal_raw_log_energy; +} + +} // namespace kaldi diff --git a/speechx/speechx/kaldi/feat/feature-spectrogram.h b/speechx/speechx/kaldi/feat/feature-spectrogram.h new file mode 100644 index 00000000..132a6875 --- /dev/null +++ b/speechx/speechx/kaldi/feat/feature-spectrogram.h @@ -0,0 +1,117 @@ +// feat/feature-spectrogram.h + +// Copyright 2009-2012 Karel Vesely +// Copyright 2012 Navdeep Jaitly + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_FEAT_FEATURE_SPECTROGRAM_H_ +#define KALDI_FEAT_FEATURE_SPECTROGRAM_H_ + + +#include + +#include "feat/feature-common.h" +#include "feat/feature-functions.h" +#include "feat/feature-window.h" + +namespace kaldi { +/// @addtogroup feat FeatureExtraction +/// @{ + + +/// SpectrogramOptions contains basic options for computing spectrogram +/// features. +struct SpectrogramOptions { + FrameExtractionOptions frame_opts; + BaseFloat energy_floor; + bool raw_energy; // If true, compute energy before preemphasis and windowing + + SpectrogramOptions() : + energy_floor(0.0), + raw_energy(true) {} + + void Register(OptionsItf *opts) { + frame_opts.Register(opts); + opts->Register("energy-floor", &energy_floor, + "Floor on energy (absolute, not relative) in Spectrogram " + "computation. 
Caution: this floor is applied to the zeroth " + "component, representing the total signal energy. The " + "floor on the individual spectrogram elements is fixed at " + "std::numeric_limits::epsilon()."); + opts->Register("raw-energy", &raw_energy, + "If true, compute energy before preemphasis and windowing"); + } +}; + +/// Class for computing spectrogram features. +class SpectrogramComputer { + public: + typedef SpectrogramOptions Options; + explicit SpectrogramComputer(const SpectrogramOptions &opts); + SpectrogramComputer(const SpectrogramComputer &other); + + const FrameExtractionOptions& GetFrameOptions() const { + return opts_.frame_opts; + } + + int32 Dim() const { return opts_.frame_opts.PaddedWindowSize() / 2 + 1; } + + bool NeedRawLogEnergy() const { return opts_.raw_energy; } + + + /** + Function that computes one frame of spectrogram features from + one frame of signal. + + @param [in] signal_raw_log_energy The log-energy of the frame of the signal + prior to windowing and pre-emphasis, or + log(numeric_limits::min()), whichever is greater. Must be + ignored by this function if this class returns false from + this->NeedsRawLogEnergy(). + @param [in] vtln_warp This is ignored by this function, it's only + needed for interface compatibility. + @param [in] signal_frame One frame of the signal, + as extracted using the function ExtractWindow() using the options + returned by this->GetFrameOptions(). The function will use the + vector as a workspace, which is why it's a non-const pointer. + @param [out] feature Pointer to a vector of size this->Dim(), to which + the computed feature will be written. + */ + void Compute(BaseFloat signal_raw_log_energy, + BaseFloat vtln_warp, + VectorBase *signal_frame, + VectorBase *feature); + + ~SpectrogramComputer(); + + private: + SpectrogramOptions opts_; + BaseFloat log_energy_floor_; + SplitRadixRealFft *srfft_; + + // Disallow assignment. 
+ SpectrogramComputer &operator=(const SpectrogramComputer &other); +}; + +typedef OfflineFeatureTpl Spectrogram; + + +/// @} End of "addtogroup feat" +} // namespace kaldi + + +#endif // KALDI_FEAT_FEATURE_SPECTROGRAM_H_ diff --git a/speechx/speechx/kaldi/feat/feature-window.cc b/speechx/speechx/kaldi/feat/feature-window.cc new file mode 100644 index 00000000..c5d4cc29 --- /dev/null +++ b/speechx/speechx/kaldi/feat/feature-window.cc @@ -0,0 +1,222 @@ +// feat/feature-window.cc + +// Copyright 2009-2011 Karel Vesely; Petr Motlicek; Microsoft Corporation +// 2013-2016 Johns Hopkins University (author: Daniel Povey) +// 2014 IMSL, PKU-HKUST (author: Wei Shi) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ + +#include "feat/feature-window.h" +#include "matrix/matrix-functions.h" + + +namespace kaldi { + + +int64 FirstSampleOfFrame(int32 frame, + const FrameExtractionOptions &opts) { + int64 frame_shift = opts.WindowShift(); + if (opts.snip_edges) { + return frame * frame_shift; + } else { + int64 midpoint_of_frame = frame_shift * frame + frame_shift / 2, + beginning_of_frame = midpoint_of_frame - opts.WindowSize() / 2; + return beginning_of_frame; + } +} + +int32 NumFrames(int64 num_samples, + const FrameExtractionOptions &opts, + bool flush) { + int64 frame_shift = opts.WindowShift(); + int64 frame_length = opts.WindowSize(); + if (opts.snip_edges) { + // with --snip-edges=true (the default), we use a HTK-like approach to + // determining the number of frames-- all frames have to fit completely into + // the waveform, and the first frame begins at sample zero. + if (num_samples < frame_length) + return 0; + else + return (1 + ((num_samples - frame_length) / frame_shift)); + // You can understand the expression above as follows: 'num_samples - + // frame_length' is how much room we have to shift the frame within the + // waveform; 'frame_shift' is how much we shift it each time; and the ratio + // is how many times we can shift it (integer arithmetic rounds down). + } else { + // if --snip-edges=false, the number of frames is determined by rounding the + // (file-length / frame-shift) to the nearest integer. The point of this + // formula is to make the number of frames an obvious and predictable + // function of the frame shift and signal length, which makes many + // segmentation-related questions simpler. + // + // Because integer division in C++ rounds toward zero, we add (half the + // frame-shift minus epsilon) before dividing, to have the effect of + // rounding towards the closest integer. + int32 num_frames = (num_samples + (frame_shift / 2)) / frame_shift; + + if (flush) + return num_frames; + + // note: 'end' always means the last plus one, i.e. 
one past the last. + int64 end_sample_of_last_frame = FirstSampleOfFrame(num_frames - 1, opts) + + frame_length; + + // the following code is optimized more for clarity than efficiency. + // If flush == false, we can't output frames that extend past the end + // of the signal. + while (num_frames > 0 && end_sample_of_last_frame > num_samples) { + num_frames--; + end_sample_of_last_frame -= frame_shift; + } + return num_frames; + } +} + + +void Dither(VectorBase *waveform, BaseFloat dither_value) { + if (dither_value == 0.0) + return; + int32 dim = waveform->Dim(); + BaseFloat *data = waveform->Data(); + RandomState rstate; + for (int32 i = 0; i < dim; i++) + data[i] += RandGauss(&rstate) * dither_value; +} + + +void Preemphasize(VectorBase *waveform, BaseFloat preemph_coeff) { + if (preemph_coeff == 0.0) return; + KALDI_ASSERT(preemph_coeff >= 0.0 && preemph_coeff <= 1.0); + for (int32 i = waveform->Dim()-1; i > 0; i--) + (*waveform)(i) -= preemph_coeff * (*waveform)(i-1); + (*waveform)(0) -= preemph_coeff * (*waveform)(0); +} + +FeatureWindowFunction::FeatureWindowFunction(const FrameExtractionOptions &opts) { + int32 frame_length = opts.WindowSize(); + KALDI_ASSERT(frame_length > 0); + window.Resize(frame_length); + double a = M_2PI / (frame_length-1); + for (int32 i = 0; i < frame_length; i++) { + double i_fl = static_cast(i); + if (opts.window_type == "hanning") { + window(i) = 0.5 - 0.5*cos(a * i_fl); + } else if (opts.window_type == "hamming") { + window(i) = 0.54 - 0.46*cos(a * i_fl); + } else if (opts.window_type == "povey") { // like hamming but goes to zero at edges. 
+ window(i) = pow(0.5 - 0.5*cos(a * i_fl), 0.85); + } else if (opts.window_type == "rectangular") { + window(i) = 1.0; + } else if (opts.window_type == "blackman") { + window(i) = opts.blackman_coeff - 0.5*cos(a * i_fl) + + (0.5 - opts.blackman_coeff) * cos(2 * a * i_fl); + } else { + KALDI_ERR << "Invalid window type " << opts.window_type; + } + } +} + +void ProcessWindow(const FrameExtractionOptions &opts, + const FeatureWindowFunction &window_function, + VectorBase *window, + BaseFloat *log_energy_pre_window) { + int32 frame_length = opts.WindowSize(); + KALDI_ASSERT(window->Dim() == frame_length); + + if (opts.dither != 0.0) + Dither(window, opts.dither); + + if (opts.remove_dc_offset) + window->Add(-window->Sum() / frame_length); + + if (log_energy_pre_window != NULL) { + BaseFloat energy = std::max(VecVec(*window, *window), + std::numeric_limits::epsilon()); + *log_energy_pre_window = Log(energy); + } + + if (opts.preemph_coeff != 0.0) + Preemphasize(window, opts.preemph_coeff); + + window->MulElements(window_function.window); +} + + +// ExtractWindow extracts a windowed frame of waveform with a power-of-two, +// padded size. It does mean subtraction, pre-emphasis and dithering as +// requested. 
+void ExtractWindow(int64 sample_offset, + const VectorBase &wave, + int32 f, // with 0 <= f < NumFrames(feats, opts) + const FrameExtractionOptions &opts, + const FeatureWindowFunction &window_function, + Vector *window, + BaseFloat *log_energy_pre_window) { + KALDI_ASSERT(sample_offset >= 0 && wave.Dim() != 0); + int32 frame_length = opts.WindowSize(), + frame_length_padded = opts.PaddedWindowSize(); + int64 num_samples = sample_offset + wave.Dim(), + start_sample = FirstSampleOfFrame(f, opts), + end_sample = start_sample + frame_length; + + if (opts.snip_edges) { + KALDI_ASSERT(start_sample >= sample_offset && + end_sample <= num_samples); + } else { + KALDI_ASSERT(sample_offset == 0 || start_sample >= sample_offset); + } + + if (window->Dim() != frame_length_padded) + window->Resize(frame_length_padded, kUndefined); + + // wave_start and wave_end are start and end indexes into 'wave', for the + // piece of wave that we're trying to extract. + int32 wave_start = int32(start_sample - sample_offset), + wave_end = wave_start + frame_length; + if (wave_start >= 0 && wave_end <= wave.Dim()) { + // the normal case-- no edge effects to consider. + window->Range(0, frame_length).CopyFromVec( + wave.Range(wave_start, frame_length)); + } else { + // Deal with any end effects by reflection, if needed. This code will only + // be reached for about two frames per utterance, so we don't concern + // ourselves excessively with efficiency. + int32 wave_dim = wave.Dim(); + for (int32 s = 0; s < frame_length; s++) { + int32 s_in_wave = s + wave_start; + while (s_in_wave < 0 || s_in_wave >= wave_dim) { + // reflect around the beginning or end of the wave. + // e.g. -1 -> 0, -2 -> 1. + // dim -> dim - 1, dim + 1 -> dim - 2. + // the code supports repeated reflections, although this + // would only be needed in pathological cases. 
+ if (s_in_wave < 0) s_in_wave = - s_in_wave - 1; + else s_in_wave = 2 * wave_dim - 1 - s_in_wave; + } + (*window)(s) = wave(s_in_wave); + } + } + + if (frame_length_padded > frame_length) + window->Range(frame_length, frame_length_padded - frame_length).SetZero(); + + SubVector frame(*window, 0, frame_length); + + ProcessWindow(opts, window_function, &frame, log_energy_pre_window); +} + +} // namespace kaldi diff --git a/speechx/speechx/kaldi/feat/feature-window.h b/speechx/speechx/kaldi/feat/feature-window.h new file mode 100644 index 00000000..a7abba50 --- /dev/null +++ b/speechx/speechx/kaldi/feat/feature-window.h @@ -0,0 +1,223 @@ +// feat/feature-window.h + +// Copyright 2009-2011 Karel Vesely; Petr Motlicek; Saarland University +// 2014-2016 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_FEAT_FEATURE_WINDOW_H_ +#define KALDI_FEAT_FEATURE_WINDOW_H_ + +#include +#include + +#include "matrix/matrix-lib.h" +#include "util/common-utils.h" +#include "base/kaldi-error.h" + +namespace kaldi { +/// @addtogroup feat FeatureExtraction +/// @{ + +struct FrameExtractionOptions { + BaseFloat samp_freq; + BaseFloat frame_shift_ms; // in milliseconds. + BaseFloat frame_length_ms; // in milliseconds. + BaseFloat dither; // Amount of dithering, 0.0 means no dither. 
+ BaseFloat preemph_coeff; // Preemphasis coefficient. + bool remove_dc_offset; // Subtract mean of wave before FFT. + std::string window_type; // e.g. Hamming window + // May be "hamming", "rectangular", "povey", "hanning", "blackman" + // "povey" is a window I made to be similar to Hamming but to go to zero at the + // edges, it's pow((0.5 - 0.5*cos(n/N*2*pi)), 0.85) + // I just don't think the Hamming window makes sense as a windowing function. + bool round_to_power_of_two; + BaseFloat blackman_coeff; + bool snip_edges; + bool allow_downsample; + bool allow_upsample; + int max_feature_vectors; + FrameExtractionOptions(): + samp_freq(16000), + frame_shift_ms(10.0), + frame_length_ms(25.0), + dither(1.0), + preemph_coeff(0.97), + remove_dc_offset(true), + window_type("povey"), + round_to_power_of_two(true), + blackman_coeff(0.42), + snip_edges(true), + allow_downsample(false), + allow_upsample(false), + max_feature_vectors(-1) + { } + + void Register(OptionsItf *opts) { + opts->Register("sample-frequency", &samp_freq, + "Waveform data sample frequency (must match the waveform file, " + "if specified there)"); + opts->Register("frame-length", &frame_length_ms, "Frame length in milliseconds"); + opts->Register("frame-shift", &frame_shift_ms, "Frame shift in milliseconds"); + opts->Register("preemphasis-coefficient", &preemph_coeff, + "Coefficient for use in signal preemphasis"); + opts->Register("remove-dc-offset", &remove_dc_offset, + "Subtract mean from waveform on each frame"); + opts->Register("dither", &dither, "Dithering constant (0.0 means no dither). " + "If you turn this off, you should set the --energy-floor " + "option, e.g. 
to 1.0 or 0.1"); + opts->Register("window-type", &window_type, "Type of window " + "(\"hamming\"|\"hanning\"|\"povey\"|\"rectangular\"" + "|\"blackmann\")"); + opts->Register("blackman-coeff", &blackman_coeff, + "Constant coefficient for generalized Blackman window."); + opts->Register("round-to-power-of-two", &round_to_power_of_two, + "If true, round window size to power of two by zero-padding " + "input to FFT."); + opts->Register("snip-edges", &snip_edges, + "If true, end effects will be handled by outputting only frames that " + "completely fit in the file, and the number of frames depends on the " + "frame-length. If false, the number of frames depends only on the " + "frame-shift, and we reflect the data at the ends."); + opts->Register("allow-downsample", &allow_downsample, + "If true, allow the input waveform to have a higher frequency than " + "the specified --sample-frequency (and we'll downsample)."); + opts->Register("max-feature-vectors", &max_feature_vectors, + "Memory optimization. If larger than 0, periodically remove feature " + "vectors so that only this number of the latest feature vectors is " + "retained."); + opts->Register("allow-upsample", &allow_upsample, + "If true, allow the input waveform to have a lower frequency than " + "the specified --sample-frequency (and we'll upsample)."); + } + int32 WindowShift() const { + return static_cast(samp_freq * 0.001 * frame_shift_ms); + } + int32 WindowSize() const { + return static_cast(samp_freq * 0.001 * frame_length_ms); + } + int32 PaddedWindowSize() const { + return (round_to_power_of_two ? 
RoundUpToNearestPowerOfTwo(WindowSize()) : + WindowSize()); + } +}; + + +struct FeatureWindowFunction { + FeatureWindowFunction() {} + explicit FeatureWindowFunction(const FrameExtractionOptions &opts); + FeatureWindowFunction(const FeatureWindowFunction &other): + window(other.window) { } + Vector window; +}; + + +/** + This function returns the number of frames that we can extract from a wave + file with the given number of samples in it (assumed to have the same + sampling rate as specified in 'opts'). + + @param [in] num_samples The number of samples in the wave file. + @param [in] opts The frame-extraction options class + + @param [in] flush True if we are asserting that this number of samples is + 'all there is', false if we expecting more data to possibly come + in. This only makes a difference to the answer if opts.snips_edges + == false. For offline feature extraction you always want flush == + true. In an online-decoding context, once you know (or decide) that + no more data is coming in, you'd call it with flush == true at the + end to flush out any remaining data. +*/ +int32 NumFrames(int64 num_samples, + const FrameExtractionOptions &opts, + bool flush = true); + +/* + This function returns the index of the first sample of the frame indexed + 'frame'. If snip-edges=true, it just returns frame * opts.WindowShift(); if + snip-edges=false, the formula is a little more complicated and the result may + be negative. +*/ +int64 FirstSampleOfFrame(int32 frame, + const FrameExtractionOptions &opts); + + + +void Dither(VectorBase *waveform, BaseFloat dither_value); + +void Preemphasize(VectorBase *waveform, BaseFloat preemph_coeff); + +/** + This function does all the windowing steps after actually + extracting the windowed signal: depending on the + configuration, it does dithering, dc offset removal, + preemphasis, and multiplication by the windowing function. 
+ @param [in] opts The options class to be used + @param [in] window_function The windowing function-- should have + been initialized using 'opts'. + @param [in,out] window A vector of size opts.WindowSize(). Note: + it will typically be a sub-vector of a larger vector of size + opts.PaddedWindowSize(), with the remaining samples zero, + as the FFT code is more efficient if it operates on data with + power-of-two size. + @param [out] log_energy_pre_window If non-NULL, then after dithering and + DC offset removal, this function will write to this pointer the log of + the total energy (i.e. sum-squared) of the frame. + */ +void ProcessWindow(const FrameExtractionOptions &opts, + const FeatureWindowFunction &window_function, + VectorBase *window, + BaseFloat *log_energy_pre_window = NULL); + + +/* + ExtractWindow() extracts a windowed frame of waveform (possibly with a + power-of-two, padded size, depending on the config), including all the + proessing done by ProcessWindow(). + + @param [in] sample_offset If 'wave' is not the entire waveform, but + part of it to the left has been discarded, then the + number of samples prior to 'wave' that we have + already discarded. Set this to zero if you are + processing the entire waveform in one piece, or + if you get 'no matching function' compilation + errors when updating the code. + @param [in] wave The waveform + @param [in] f The frame index to be extracted, with + 0 <= f < NumFrames(sample_offset + wave.Dim(), opts, true) + @param [in] opts The options class to be used + @param [in] window_function The windowing function, as derived from the + options class. + @param [out] window The windowed, possibly-padded waveform to be + extracted. Will be resized as needed. + @param [out] log_energy_pre_window If non-NULL, the log-energy of + the signal prior to pre-emphasis and multiplying by + the windowing function will be written to here. 
+*/ +void ExtractWindow(int64 sample_offset, + const VectorBase &wave, + int32 f, + const FrameExtractionOptions &opts, + const FeatureWindowFunction &window_function, + Vector *window, + BaseFloat *log_energy_pre_window = NULL); + + +/// @} End of "addtogroup feat" +} // namespace kaldi + + +#endif // KALDI_FEAT_FEATURE_WINDOW_H_ diff --git a/speechx/speechx/kaldi/feat/mel-computations.cc b/speechx/speechx/kaldi/feat/mel-computations.cc new file mode 100644 index 00000000..bb5e9f9a --- /dev/null +++ b/speechx/speechx/kaldi/feat/mel-computations.cc @@ -0,0 +1,340 @@ +// feat/mel-computations.cc + +// Copyright 2009-2011 Phonexia s.r.o.; Karel Vesely; Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include +#include +#include + +#include "feat/feature-functions.h" +#include "feat/feature-window.h" +#include "feat/mel-computations.h" + +namespace kaldi { + + +MelBanks::MelBanks(const MelBanksOptions &opts, + const FrameExtractionOptions &frame_opts, + BaseFloat vtln_warp_factor): + htk_mode_(opts.htk_mode) { + int32 num_bins = opts.num_bins; + if (num_bins < 3) KALDI_ERR << "Must have at least 3 mel bins"; + BaseFloat sample_freq = frame_opts.samp_freq; + int32 window_length_padded = frame_opts.PaddedWindowSize(); + KALDI_ASSERT(window_length_padded % 2 == 0); + int32 num_fft_bins = window_length_padded / 2; + BaseFloat nyquist = 0.5 * sample_freq; + + BaseFloat low_freq = opts.low_freq, high_freq; + if (opts.high_freq > 0.0) + high_freq = opts.high_freq; + else + high_freq = nyquist + opts.high_freq; + + if (low_freq < 0.0 || low_freq >= nyquist + || high_freq <= 0.0 || high_freq > nyquist + || high_freq <= low_freq) + KALDI_ERR << "Bad values in options: low-freq " << low_freq + << " and high-freq " << high_freq << " vs. nyquist " + << nyquist; + + BaseFloat fft_bin_width = sample_freq / window_length_padded; + // fft-bin width [think of it as Nyquist-freq / half-window-length] + + BaseFloat mel_low_freq = MelScale(low_freq); + BaseFloat mel_high_freq = MelScale(high_freq); + + debug_ = opts.debug_mel; + + // divide by num_bins+1 in next line because of end-effects where the bins + // spread out to the sides. 
+ BaseFloat mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins+1); + + BaseFloat vtln_low = opts.vtln_low, + vtln_high = opts.vtln_high; + if (vtln_high < 0.0) { + vtln_high += nyquist; + } + + if (vtln_warp_factor != 1.0 && + (vtln_low < 0.0 || vtln_low <= low_freq + || vtln_low >= high_freq + || vtln_high <= 0.0 || vtln_high >= high_freq + || vtln_high <= vtln_low)) + KALDI_ERR << "Bad values in options: vtln-low " << vtln_low + << " and vtln-high " << vtln_high << ", versus " + << "low-freq " << low_freq << " and high-freq " + << high_freq; + + bins_.resize(num_bins); + center_freqs_.Resize(num_bins); + + for (int32 bin = 0; bin < num_bins; bin++) { + BaseFloat left_mel = mel_low_freq + bin * mel_freq_delta, + center_mel = mel_low_freq + (bin + 1) * mel_freq_delta, + right_mel = mel_low_freq + (bin + 2) * mel_freq_delta; + + if (vtln_warp_factor != 1.0) { + left_mel = VtlnWarpMelFreq(vtln_low, vtln_high, low_freq, high_freq, + vtln_warp_factor, left_mel); + center_mel = VtlnWarpMelFreq(vtln_low, vtln_high, low_freq, high_freq, + vtln_warp_factor, center_mel); + right_mel = VtlnWarpMelFreq(vtln_low, vtln_high, low_freq, high_freq, + vtln_warp_factor, right_mel); + } + center_freqs_(bin) = InverseMelScale(center_mel); + // this_bin will be a vector of coefficients that is only + // nonzero where this mel bin is active. + Vector this_bin(num_fft_bins); + int32 first_index = -1, last_index = -1; + for (int32 i = 0; i < num_fft_bins; i++) { + BaseFloat freq = (fft_bin_width * i); // Center frequency of this fft + // bin. 
+ BaseFloat mel = MelScale(freq); + if (mel > left_mel && mel < right_mel) { + BaseFloat weight; + if (mel <= center_mel) + weight = (mel - left_mel) / (center_mel - left_mel); + else + weight = (right_mel-mel) / (right_mel-center_mel); + this_bin(i) = weight; + if (first_index == -1) + first_index = i; + last_index = i; + } + } + KALDI_ASSERT(first_index != -1 && last_index >= first_index + && "You may have set --num-mel-bins too large."); + + bins_[bin].first = first_index; + int32 size = last_index + 1 - first_index; + bins_[bin].second.Resize(size); + bins_[bin].second.CopyFromVec(this_bin.Range(first_index, size)); + + // Replicate a bug in HTK, for testing purposes. + if (opts.htk_mode && bin == 0 && mel_low_freq != 0.0) + bins_[bin].second(0) = 0.0; + + } + if (debug_) { + for (size_t i = 0; i < bins_.size(); i++) { + KALDI_LOG << "bin " << i << ", offset = " << bins_[i].first + << ", vec = " << bins_[i].second; + } + } +} + +MelBanks::MelBanks(const MelBanks &other): + center_freqs_(other.center_freqs_), + bins_(other.bins_), + debug_(other.debug_), + htk_mode_(other.htk_mode_) { } + +BaseFloat MelBanks::VtlnWarpFreq(BaseFloat vtln_low_cutoff, // upper+lower frequency cutoffs for VTLN. + BaseFloat vtln_high_cutoff, + BaseFloat low_freq, // upper+lower frequency cutoffs in mel computation + BaseFloat high_freq, + BaseFloat vtln_warp_factor, + BaseFloat freq) { + /// This computes a VTLN warping function that is not the same as HTK's one, + /// but has similar inputs (this function has the advantage of never producing + /// empty bins). + + /// This function computes a warp function F(freq), defined between low_freq and + /// high_freq inclusive, with the following properties: + /// F(low_freq) == low_freq + /// F(high_freq) == high_freq + /// The function is continuous and piecewise linear with two inflection + /// points. + /// The lower inflection point (measured in terms of the unwarped + /// frequency) is at frequency l, determined as described below. 
+ /// The higher inflection point is at a frequency h, determined as + /// described below. + /// If l <= f <= h, then F(f) = f/vtln_warp_factor. + /// If the higher inflection point (measured in terms of the unwarped + /// frequency) is at h, then max(h, F(h)) == vtln_high_cutoff. + /// Since (by the last point) F(h) == h/vtln_warp_factor, then + /// max(h, h/vtln_warp_factor) == vtln_high_cutoff, so + /// h = vtln_high_cutoff / max(1, 1/vtln_warp_factor). + /// = vtln_high_cutoff * min(1, vtln_warp_factor). + /// If the lower inflection point (measured in terms of the unwarped + /// frequency) is at l, then min(l, F(l)) == vtln_low_cutoff + /// This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor) + /// = vtln_low_cutoff * max(1, vtln_warp_factor) + + + if (freq < low_freq || freq > high_freq) return freq; // in case this gets called + // for out-of-range frequencies, just return the freq. + + KALDI_ASSERT(vtln_low_cutoff > low_freq && + "be sure to set the --vtln-low option higher than --low-freq"); + KALDI_ASSERT(vtln_high_cutoff < high_freq && + "be sure to set the --vtln-high option lower than --high-freq [or negative]"); + BaseFloat one = 1.0; + BaseFloat l = vtln_low_cutoff * std::max(one, vtln_warp_factor); + BaseFloat h = vtln_high_cutoff * std::min(one, vtln_warp_factor); + BaseFloat scale = 1.0 / vtln_warp_factor; + BaseFloat Fl = scale * l; // F(l); + BaseFloat Fh = scale * h; // F(h); + KALDI_ASSERT(l > low_freq && h < high_freq); + // slope of left part of the 3-piece linear function + BaseFloat scale_left = (Fl - low_freq) / (l - low_freq); + // [slope of center part is just "scale"] + + // slope of right part of the 3-piece linear function + BaseFloat scale_right = (high_freq - Fh) / (high_freq - h); + + if (freq < l) { + return low_freq + scale_left * (freq - low_freq); + } else if (freq < h) { + return scale * freq; + } else { // freq >= h + return high_freq + scale_right * (freq - high_freq); + } +} + +BaseFloat 
MelBanks::VtlnWarpMelFreq(BaseFloat vtln_low_cutoff, // upper+lower frequency cutoffs for VTLN. + BaseFloat vtln_high_cutoff, + BaseFloat low_freq, // upper+lower frequency cutoffs in mel computation + BaseFloat high_freq, + BaseFloat vtln_warp_factor, + BaseFloat mel_freq) { + return MelScale(VtlnWarpFreq(vtln_low_cutoff, vtln_high_cutoff, + low_freq, high_freq, + vtln_warp_factor, InverseMelScale(mel_freq))); +} + + +// "power_spectrum" contains fft energies. +void MelBanks::Compute(const VectorBase &power_spectrum, + VectorBase *mel_energies_out) const { + int32 num_bins = bins_.size(); + KALDI_ASSERT(mel_energies_out->Dim() == num_bins); + + for (int32 i = 0; i < num_bins; i++) { + int32 offset = bins_[i].first; + const Vector &v(bins_[i].second); + BaseFloat energy = VecVec(v, power_spectrum.Range(offset, v.Dim())); + // HTK-like flooring- for testing purposes (we prefer dither) + if (htk_mode_ && energy < 1.0) energy = 1.0; + (*mel_energies_out)(i) = energy; + + // The following assert was added due to a problem with OpenBlas that + // we had at one point (it was a bug in that library). Just to detect + // it early. + KALDI_ASSERT(!KALDI_ISNAN((*mel_energies_out)(i))); + } + + if (debug_) { + fprintf(stderr, "MEL BANKS:\n"); + for (int32 i = 0; i < num_bins; i++) + fprintf(stderr, " %f", (*mel_energies_out)(i)); + fprintf(stderr, "\n"); + } +} + +void ComputeLifterCoeffs(BaseFloat Q, VectorBase *coeffs) { + // Compute liftering coefficients (scaling on cepstral coeffs) + // coeffs are numbered slightly differently from HTK: the zeroth + // index is C0, which is not affected. 
+ for (int32 i = 0; i < coeffs->Dim(); i++) + (*coeffs)(i) = 1.0 + 0.5 * Q * sin (M_PI * i / Q); +} + + +// Durbin's recursion - converts autocorrelation coefficients to the LPC +// pTmp - temporal place [n] +// pAC - autocorrelation coefficients [n + 1] +// pLP - linear prediction coefficients [n] (predicted_sn = sum_1^P{a[i-1] * s[n-i]}}) +// F(z) = 1 / (1 - A(z)), 1 is not stored in the demoninator +BaseFloat Durbin(int n, const BaseFloat *pAC, BaseFloat *pLP, BaseFloat *pTmp) { + BaseFloat ki; // reflection coefficient + int i; + int j; + + BaseFloat E = pAC[0]; + + for (i = 0; i < n; i++) { + // next reflection coefficient + ki = pAC[i + 1]; + for (j = 0; j < i; j++) + ki += pLP[j] * pAC[i - j]; + ki = ki / E; + + // new error + BaseFloat c = 1 - ki * ki; + if (c < 1.0e-5) // remove NaNs for constan signal + c = 1.0e-5; + E *= c; + + // new LP coefficients + pTmp[i] = -ki; + for (j = 0; j < i; j++) + pTmp[j] = pLP[j] - ki * pLP[i - j - 1]; + + for (j = 0; j <= i; j++) + pLP[j] = pTmp[j]; + } + + return E; +} + + +void Lpc2Cepstrum(int n, const BaseFloat *pLPC, BaseFloat *pCepst) { + for (int32 i = 0; i < n; i++) { + double sum = 0.0; + int j; + for (j = 0; j < i; j++) { + sum += static_cast(i - j) * pLPC[j] * pCepst[i - j - 1]; + } + pCepst[i] = -pLPC[i] - sum / static_cast(i + 1); + } +} + +void GetEqualLoudnessVector(const MelBanks &mel_banks, + Vector *ans) { + int32 n = mel_banks.NumBins(); + // Central frequency of each mel bin. + const Vector &f0 = mel_banks.GetCenterFreqs(); + ans->Resize(n); + for (int32 i = 0; i < n; i++) { + BaseFloat fsq = f0(i) * f0(i); + BaseFloat fsub = fsq / (fsq + 1.6e5); + (*ans)(i) = fsub * fsub * ((fsq + 1.44e6) / (fsq + 9.61e6)); + } +} + + +// Compute LP coefficients from autocorrelation coefficients. 
+BaseFloat ComputeLpc(const VectorBase &autocorr_in, + Vector *lpc_out) { + int32 n = autocorr_in.Dim() - 1; + KALDI_ASSERT(lpc_out->Dim() == n); + Vector tmp(n); + BaseFloat ans = Durbin(n, autocorr_in.Data(), + lpc_out->Data(), + tmp.Data()); + if (ans <= 0.0) + KALDI_WARN << "Zero energy in LPC computation"; + return -Log(1.0 / ans); // forms the C0 value +} + + +} // namespace kaldi diff --git a/speechx/speechx/kaldi/feat/mel-computations.h b/speechx/speechx/kaldi/feat/mel-computations.h new file mode 100644 index 00000000..0c1d41ca --- /dev/null +++ b/speechx/speechx/kaldi/feat/mel-computations.h @@ -0,0 +1,171 @@ +// feat/mel-computations.h + +// Copyright 2009-2011 Phonexia s.r.o.; Microsoft Corporation +// 2016 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_FEAT_MEL_COMPUTATIONS_H_ +#define KALDI_FEAT_MEL_COMPUTATIONS_H_ + +#include +#include +#include +#include +#include +#include + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/matrix-lib.h" + + +namespace kaldi { +/// @addtogroup feat FeatureExtraction +/// @{ + +struct FrameExtractionOptions; // defined in feature-window.h + + +struct MelBanksOptions { + int32 num_bins; // e.g. 25; number of triangular bins + BaseFloat low_freq; // e.g. 
20; lower frequency cutoff + BaseFloat high_freq; // an upper frequency cutoff; 0 -> no cutoff, negative + // ->added to the Nyquist frequency to get the cutoff. + BaseFloat vtln_low; // vtln lower cutoff of warping function. + BaseFloat vtln_high; // vtln upper cutoff of warping function: if negative, added + // to the Nyquist frequency to get the cutoff. + bool debug_mel; + // htk_mode is a "hidden" config, it does not show up on command line. + // Enables more exact compatibility with HTK, for testing purposes. Affects + // mel-energy flooring and reproduces a bug in HTK. + bool htk_mode; + explicit MelBanksOptions(int num_bins = 25) + : num_bins(num_bins), low_freq(20), high_freq(0), vtln_low(100), + vtln_high(-500), debug_mel(false), htk_mode(false) {} + + void Register(OptionsItf *opts) { + opts->Register("num-mel-bins", &num_bins, + "Number of triangular mel-frequency bins"); + opts->Register("low-freq", &low_freq, + "Low cutoff frequency for mel bins"); + opts->Register("high-freq", &high_freq, + "High cutoff frequency for mel bins (if <= 0, offset from Nyquist)"); + opts->Register("vtln-low", &vtln_low, + "Low inflection point in piecewise linear VTLN warping function"); + opts->Register("vtln-high", &vtln_high, + "High inflection point in piecewise linear VTLN warping function" + " (if negative, offset from high-mel-freq"); + opts->Register("debug-mel", &debug_mel, + "Print out debugging information for mel bin computation"); + } +}; + + +class MelBanks { + public: + + static inline BaseFloat InverseMelScale(BaseFloat mel_freq) { + return 700.0f * (expf (mel_freq / 1127.0f) - 1.0f); + } + + static inline BaseFloat MelScale(BaseFloat freq) { + return 1127.0f * logf (1.0f + freq / 700.0f); + } + + static BaseFloat VtlnWarpFreq(BaseFloat vtln_low_cutoff, + BaseFloat vtln_high_cutoff, // discontinuities in warp func + BaseFloat low_freq, + BaseFloat high_freq, // upper+lower frequency cutoffs in + // the mel computation + BaseFloat vtln_warp_factor, + 
BaseFloat freq); + + static BaseFloat VtlnWarpMelFreq(BaseFloat vtln_low_cutoff, + BaseFloat vtln_high_cutoff, + BaseFloat low_freq, + BaseFloat high_freq, + BaseFloat vtln_warp_factor, + BaseFloat mel_freq); + + + MelBanks(const MelBanksOptions &opts, + const FrameExtractionOptions &frame_opts, + BaseFloat vtln_warp_factor); + + /// Compute Mel energies (note: not log enerties). + /// At input, "fft_energies" contains the FFT energies (not log). + void Compute(const VectorBase &fft_energies, + VectorBase *mel_energies_out) const; + + int32 NumBins() const { return bins_.size(); } + + // returns vector of central freq of each bin; needed by plp code. + const Vector &GetCenterFreqs() const { return center_freqs_; } + + const std::vector > >& GetBins() const { + return bins_; + } + + // Copy constructor + MelBanks(const MelBanks &other); + private: + // Disallow assignment + MelBanks &operator = (const MelBanks &other); + + // center frequencies of bins, numbered from 0 ... num_bins-1. + // Needed by GetCenterFreqs(). + Vector center_freqs_; + + // the "bins_" vector is a vector, one for each bin, of a pair: + // (the first nonzero fft-bin), (the vector of weights). + std::vector > > bins_; + + bool debug_; + bool htk_mode_; +}; + + +// Compute liftering coefficients (scaling on cepstral coeffs) +// coeffs are numbered slightly differently from HTK: the zeroth +// index is C0, which is not affected. +void ComputeLifterCoeffs(BaseFloat Q, VectorBase *coeffs); + + +// Durbin's recursion - converts autocorrelation coefficients to the LPC +// pTmp - temporal place [n] +// pAC - autocorrelation coefficients [n + 1] +// pLP - linear prediction coefficients [n] (predicted_sn = sum_1^P{a[i-1] * s[n-i]}}) +// F(z) = 1 / (1 - A(z)), 1 is not stored in the denominator +// Returns log energy of residual (I think) +BaseFloat Durbin(int n, const BaseFloat *pAC, BaseFloat *pLP, BaseFloat *pTmp); + +// Compute LP coefficients from autocorrelation coefficients. 
+// Returns log energy of residual (I think) +BaseFloat ComputeLpc(const VectorBase &autocorr_in, + Vector *lpc_out); + +void Lpc2Cepstrum(int n, const BaseFloat *pLPC, BaseFloat *pCepst); + + + +void GetEqualLoudnessVector(const MelBanks &mel_banks, + Vector *ans); + +/// @} End of "addtogroup feat" +} // namespace kaldi + +#endif // KALDI_FEAT_MEL_COMPUTATIONS_H_ diff --git a/speechx/speechx/kaldi/feat/online-feature.cc b/speechx/speechx/kaldi/feat/online-feature.cc new file mode 100644 index 00000000..047909e7 --- /dev/null +++ b/speechx/speechx/kaldi/feat/online-feature.cc @@ -0,0 +1,679 @@ +// feat/online-feature.cc + +// Copyright 2013 Johns Hopkins University (author: Daniel Povey) +// 2014 Yanqing Sun, Junjie Wang, +// Daniel Povey, Korbinian Riedhammer + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "feat/online-feature.h" +#include "transform/cmvn.h" + +namespace kaldi { + +RecyclingVector::RecyclingVector(int items_to_hold): + items_to_hold_(items_to_hold == 0 ? 
-1 : items_to_hold), + first_available_index_(0) { +} + +RecyclingVector::~RecyclingVector() { + for (auto *item : items_) { + delete item; + } +} + +Vector *RecyclingVector::At(int index) const { + if (index < first_available_index_) { + KALDI_ERR << "Attempted to retrieve feature vector that was " + "already removed by the RecyclingVector (index = " + << index << "; " + << "first_available_index = " << first_available_index_ << "; " + << "size = " << Size() << ")"; + } + // 'at' does size checking. + return items_.at(index - first_available_index_); +} + +void RecyclingVector::PushBack(Vector *item) { + if (items_.size() == items_to_hold_) { + delete items_.front(); + items_.pop_front(); + ++first_available_index_; + } + items_.push_back(item); +} + +int RecyclingVector::Size() const { + return first_available_index_ + items_.size(); +} + +template +void OnlineGenericBaseFeature::GetFrame(int32 frame, + VectorBase *feat) { + feat->CopyFromVec(*(features_.At(frame))); +}; + +template +OnlineGenericBaseFeature::OnlineGenericBaseFeature( + const typename C::Options &opts): + computer_(opts), window_function_(computer_.GetFrameOptions()), + features_(opts.frame_opts.max_feature_vectors), + input_finished_(false), waveform_offset_(0) { + // RE the following assert: search for ONLINE_IVECTOR_LIMIT in + // online-ivector-feature.cc. + // Casting to uint32, an unsigned type, means that -1 would be treated + // as `very large`. 
+ KALDI_ASSERT(static_cast(opts.frame_opts.max_feature_vectors) > 200); +} + + +template +void OnlineGenericBaseFeature::MaybeCreateResampler( + BaseFloat sampling_rate) { + BaseFloat expected_sampling_rate = computer_.GetFrameOptions().samp_freq; + + if (resampler_ != nullptr) { + KALDI_ASSERT(resampler_->GetInputSamplingRate() == sampling_rate); + KALDI_ASSERT(resampler_->GetOutputSamplingRate() == expected_sampling_rate); + } else if (((sampling_rate < expected_sampling_rate) && + computer_.GetFrameOptions().allow_downsample) || + ((sampling_rate > expected_sampling_rate) && + computer_.GetFrameOptions().allow_upsample)) { + resampler_.reset(new LinearResample( + sampling_rate, expected_sampling_rate, + std::min(sampling_rate / 2, expected_sampling_rate / 2), 6)); + } else if (sampling_rate != expected_sampling_rate) { + KALDI_ERR << "Sampling frequency mismatch, expected " + << expected_sampling_rate << ", got " << sampling_rate + << "\nPerhaps you want to use the options " + "--allow_{upsample,downsample}"; + } +} + +template +void OnlineGenericBaseFeature::InputFinished() { + if (resampler_ != nullptr) { + // There may be a few samples left once we flush the resampler_ object, telling it + // that the file has finished. This should rarely make any difference. 
+ Vector appended_wave; + Vector resampled_wave; + resampler_->Resample(appended_wave, true, &resampled_wave); + + if (resampled_wave.Dim() != 0) { + appended_wave.Resize(waveform_remainder_.Dim() + + resampled_wave.Dim()); + if (waveform_remainder_.Dim() != 0) + appended_wave.Range(0, waveform_remainder_.Dim()) + .CopyFromVec(waveform_remainder_); + appended_wave.Range(waveform_remainder_.Dim(), resampled_wave.Dim()) + .CopyFromVec(resampled_wave); + waveform_remainder_.Swap(&appended_wave); + } + } + input_finished_ = true; + ComputeFeatures(); +} + +template +void OnlineGenericBaseFeature::AcceptWaveform( + BaseFloat sampling_rate, const VectorBase &original_waveform) { + if (original_waveform.Dim() == 0) + return; // Nothing to do. + if (input_finished_) + KALDI_ERR << "AcceptWaveform called after InputFinished() was called."; + + Vector appended_wave; + Vector resampled_wave; + + const VectorBase *waveform; + + MaybeCreateResampler(sampling_rate); + if (resampler_ == nullptr) { + waveform = &original_waveform; + } else { + resampler_->Resample(original_waveform, false, &resampled_wave); + waveform = &resampled_wave; + } + + appended_wave.Resize(waveform_remainder_.Dim() + waveform->Dim()); + if (waveform_remainder_.Dim() != 0) + appended_wave.Range(0, waveform_remainder_.Dim()) + .CopyFromVec(waveform_remainder_); + appended_wave.Range(waveform_remainder_.Dim(), waveform->Dim()) + .CopyFromVec(*waveform); + waveform_remainder_.Swap(&appended_wave); + ComputeFeatures(); +} + +template +void OnlineGenericBaseFeature::ComputeFeatures() { + const FrameExtractionOptions &frame_opts = computer_.GetFrameOptions(); + int64 num_samples_total = waveform_offset_ + waveform_remainder_.Dim(); + int32 num_frames_old = features_.Size(), + num_frames_new = NumFrames(num_samples_total, frame_opts, + input_finished_); + KALDI_ASSERT(num_frames_new >= num_frames_old); + + Vector window; + bool need_raw_log_energy = computer_.NeedRawLogEnergy(); + for (int32 frame = 
num_frames_old; frame < num_frames_new; frame++) { + BaseFloat raw_log_energy = 0.0; + ExtractWindow(waveform_offset_, waveform_remainder_, frame, + frame_opts, window_function_, &window, + need_raw_log_energy ? &raw_log_energy : NULL); + Vector *this_feature = new Vector(computer_.Dim(), + kUndefined); + // note: this online feature-extraction code does not support VTLN. + BaseFloat vtln_warp = 1.0; + computer_.Compute(raw_log_energy, vtln_warp, &window, this_feature); + features_.PushBack(this_feature); + } + // OK, we will now discard any portion of the signal that will not be + // necessary to compute frames in the future. + int64 first_sample_of_next_frame = FirstSampleOfFrame(num_frames_new, + frame_opts); + int32 samples_to_discard = first_sample_of_next_frame - waveform_offset_; + if (samples_to_discard > 0) { + // discard the leftmost part of the waveform that we no longer need. + int32 new_num_samples = waveform_remainder_.Dim() - samples_to_discard; + if (new_num_samples <= 0) { + // odd, but we'll try to handle it. + waveform_offset_ += waveform_remainder_.Dim(); + waveform_remainder_.Resize(0); + } else { + Vector new_remainder(new_num_samples); + new_remainder.CopyFromVec(waveform_remainder_.Range(samples_to_discard, + new_num_samples)); + waveform_offset_ += samples_to_discard; + waveform_remainder_.Swap(&new_remainder); + } + } +} + +// instantiate the templates defined here for MFCC, PLP and filterbank classes. +template class OnlineGenericBaseFeature; +template class OnlineGenericBaseFeature; +template class OnlineGenericBaseFeature; + +OnlineCmvnState::OnlineCmvnState(const OnlineCmvnState &other): + speaker_cmvn_stats(other.speaker_cmvn_stats), + global_cmvn_stats(other.global_cmvn_stats), + frozen_state(other.frozen_state) { } + +void OnlineCmvnState::Write(std::ostream &os, bool binary) const { + WriteToken(os, binary, ""); // magic string. 
+ WriteToken(os, binary, ""); + speaker_cmvn_stats.Write(os, binary); + WriteToken(os, binary, ""); + global_cmvn_stats.Write(os, binary); + WriteToken(os, binary, ""); + frozen_state.Write(os, binary); + WriteToken(os, binary, ""); +} + +void OnlineCmvnState::Read(std::istream &is, bool binary) { + ExpectToken(is, binary, ""); // magic string. + ExpectToken(is, binary, ""); + speaker_cmvn_stats.Read(is, binary); + ExpectToken(is, binary, ""); + global_cmvn_stats.Read(is, binary); + ExpectToken(is, binary, ""); + frozen_state.Read(is, binary); + ExpectToken(is, binary, ""); +} + +OnlineCmvn::OnlineCmvn(const OnlineCmvnOptions &opts, + const OnlineCmvnState &cmvn_state, + OnlineFeatureInterface *src): + opts_(opts), temp_stats_(2, src->Dim() + 1), + temp_feats_(src->Dim()), temp_feats_dbl_(src->Dim()), + src_(src) { + SetState(cmvn_state); + if (!SplitStringToIntegers(opts.skip_dims, ":", false, &skip_dims_)) + KALDI_ERR << "Bad --skip-dims option (should be colon-separated list of " + << "integers)"; +} + +OnlineCmvn::OnlineCmvn(const OnlineCmvnOptions &opts, + OnlineFeatureInterface *src): + opts_(opts), temp_stats_(2, src->Dim() + 1), + temp_feats_(src->Dim()), temp_feats_dbl_(src->Dim()), + src_(src) { + if (!SplitStringToIntegers(opts.skip_dims, ":", false, &skip_dims_)) + KALDI_ERR << "Bad --skip-dims option (should be colon-separated list of " + << "integers)"; +} + + +void OnlineCmvn::GetMostRecentCachedFrame(int32 frame, + int32 *cached_frame, + MatrixBase *stats) { + KALDI_ASSERT(frame >= 0); + InitRingBufferIfNeeded(); + // look for a cached frame on a previous frame as close as possible in time + // to "frame". Return if we get one. + for (int32 t = frame; t >= 0 && t >= frame - opts_.ring_buffer_size; t--) { + if (t % opts_.modulus == 0) { + // if this frame should be cached in cached_stats_modulo_, then + // we'll look there, and we won't go back any further in time. 
+ break; + } + int32 index = t % opts_.ring_buffer_size; + if (cached_stats_ring_[index].first == t) { + *cached_frame = t; + stats->CopyFromMat(cached_stats_ring_[index].second); + return; + } + } + int32 n = frame / opts_.modulus; + if (n >= cached_stats_modulo_.size()) { + if (cached_stats_modulo_.size() == 0) { + *cached_frame = -1; + stats->SetZero(); + return; + } else { + n = static_cast(cached_stats_modulo_.size() - 1); + } + } + *cached_frame = n * opts_.modulus; + KALDI_ASSERT(cached_stats_modulo_[n] != NULL); + stats->CopyFromMat(*(cached_stats_modulo_[n])); +} + +// Initialize ring buffer for caching stats. +void OnlineCmvn::InitRingBufferIfNeeded() { + if (cached_stats_ring_.empty() && opts_.ring_buffer_size > 0) { + Matrix temp(2, this->Dim() + 1); + cached_stats_ring_.resize(opts_.ring_buffer_size, + std::pair >(-1, temp)); + } +} + +void OnlineCmvn::CacheFrame(int32 frame, const MatrixBase &stats) { + KALDI_ASSERT(frame >= 0); + if (frame % opts_.modulus == 0) { // store in cached_stats_modulo_. + int32 n = frame / opts_.modulus; + if (n >= cached_stats_modulo_.size()) { + // The following assert is a limitation on in what order you can call + // CacheFrame. Fortunately the calling code always calls it in sequence, + // which it has to because you need a previous frame to compute the + // current one. + KALDI_ASSERT(n == cached_stats_modulo_.size()); + cached_stats_modulo_.push_back(new Matrix(stats)); + } else { + KALDI_WARN << "Did not expect to reach this part of code."; + // do what seems right, but we shouldn't get here. + cached_stats_modulo_[n]->CopyFromMat(stats); + } + } else { // store in the ring buffer. 
+ InitRingBufferIfNeeded(); + if (!cached_stats_ring_.empty()) { + int32 index = frame % cached_stats_ring_.size(); + cached_stats_ring_[index].first = frame; + cached_stats_ring_[index].second.CopyFromMat(stats); + } + } +} + +OnlineCmvn::~OnlineCmvn() { + for (size_t i = 0; i < cached_stats_modulo_.size(); i++) + delete cached_stats_modulo_[i]; + cached_stats_modulo_.clear(); +} + +void OnlineCmvn::ComputeStatsForFrame(int32 frame, + MatrixBase *stats_out) { + KALDI_ASSERT(frame >= 0 && frame < src_->NumFramesReady()); + + int32 dim = this->Dim(), cur_frame; + GetMostRecentCachedFrame(frame, &cur_frame, stats_out); + + Vector &feats(temp_feats_); + Vector &feats_dbl(temp_feats_dbl_); + while (cur_frame < frame) { + cur_frame++; + src_->GetFrame(cur_frame, &feats); + feats_dbl.CopyFromVec(feats); + stats_out->Row(0).Range(0, dim).AddVec(1.0, feats_dbl); + if (opts_.normalize_variance) + stats_out->Row(1).Range(0, dim).AddVec2(1.0, feats_dbl); + (*stats_out)(0, dim) += 1.0; + // it's a sliding buffer; a frame at the back may be + // leaving the buffer so we have to subtract that. + int32 prev_frame = cur_frame - opts_.cmn_window; + if (prev_frame >= 0) { + // we need to subtract frame prev_f from the stats. + src_->GetFrame(prev_frame, &feats); + feats_dbl.CopyFromVec(feats); + stats_out->Row(0).Range(0, dim).AddVec(-1.0, feats_dbl); + if (opts_.normalize_variance) + stats_out->Row(1).Range(0, dim).AddVec2(-1.0, feats_dbl); + (*stats_out)(0, dim) -= 1.0; + } + CacheFrame(cur_frame, (*stats_out)); + } +} + + +// static +void OnlineCmvn::SmoothOnlineCmvnStats(const MatrixBase &speaker_stats, + const MatrixBase &global_stats, + const OnlineCmvnOptions &opts, + MatrixBase *stats) { + if (speaker_stats.NumRows() == 2 && !opts.normalize_variance) { + // this is just for efficiency: don't operate on the variance if it's not + // needed. 
+ int32 cols = speaker_stats.NumCols(); // dim + 1 + SubMatrix stats_temp(*stats, 0, 1, 0, cols); + SmoothOnlineCmvnStats(speaker_stats.RowRange(0, 1), + global_stats.RowRange(0, 1), + opts, &stats_temp); + return; + } + int32 dim = stats->NumCols() - 1; + double cur_count = (*stats)(0, dim); + // If count exceeded cmn_window it would be an error in how "window_stats" + // was accumulated. + KALDI_ASSERT(cur_count <= 1.001 * opts.cmn_window); + if (cur_count >= opts.cmn_window) + return; + if (speaker_stats.NumRows() != 0) { // if we have speaker stats.. + double count_from_speaker = opts.cmn_window - cur_count, + speaker_count = speaker_stats(0, dim); + if (count_from_speaker > opts.speaker_frames) + count_from_speaker = opts.speaker_frames; + if (count_from_speaker > speaker_count) + count_from_speaker = speaker_count; + if (count_from_speaker > 0.0) + stats->AddMat(count_from_speaker / speaker_count, + speaker_stats); + cur_count = (*stats)(0, dim); + } + if (cur_count >= opts.cmn_window) + return; + if (global_stats.NumRows() != 0) { + double count_from_global = opts.cmn_window - cur_count, + global_count = global_stats(0, dim); + KALDI_ASSERT(global_count > 0.0); + if (count_from_global > opts.global_frames) + count_from_global = opts.global_frames; + if (count_from_global > 0.0) + stats->AddMat(count_from_global / global_count, + global_stats); + } else { + KALDI_ERR << "Global CMN stats are required"; + } +} + +void OnlineCmvn::GetFrame(int32 frame, + VectorBase *feat) { + src_->GetFrame(frame, feat); + KALDI_ASSERT(feat->Dim() == this->Dim()); + int32 dim = feat->Dim(); + Matrix &stats(temp_stats_); + stats.Resize(2, dim + 1, kUndefined); // Will do nothing if size was correct. + if (frozen_state_.NumRows() != 0) { // the CMVN state has been frozen. + stats.CopyFromMat(frozen_state_); + } else { + // first get the raw CMVN stats (this involves caching..) + this->ComputeStatsForFrame(frame, &stats); + // now smooth them. 
+ SmoothOnlineCmvnStats(orig_state_.speaker_cmvn_stats, + orig_state_.global_cmvn_stats, + opts_, + &stats); + } + + if (!skip_dims_.empty()) + FakeStatsForSomeDims(skip_dims_, &stats); + + // call the function ApplyCmvn declared in ../transform/cmvn.h, which + // requires a matrix. + // 1 row; num-cols == dim; stride == dim. + SubMatrix feat_mat(feat->Data(), 1, dim, dim); + // the function ApplyCmvn takes a matrix, so form a one-row matrix to give it. + if (opts_.normalize_mean) + ApplyCmvn(stats, opts_.normalize_variance, &feat_mat); + else + KALDI_ASSERT(!opts_.normalize_variance); +} + +void OnlineCmvn::Freeze(int32 cur_frame) { + int32 dim = this->Dim(); + Matrix stats(2, dim + 1); + // get the raw CMVN stats + this->ComputeStatsForFrame(cur_frame, &stats); + // now smooth them. + SmoothOnlineCmvnStats(orig_state_.speaker_cmvn_stats, + orig_state_.global_cmvn_stats, + opts_, + &stats); + this->frozen_state_ = stats; +} + +void OnlineCmvn::GetState(int32 cur_frame, + OnlineCmvnState *state_out) { + *state_out = this->orig_state_; + { // This block updates state_out->speaker_cmvn_stats + int32 dim = this->Dim(); + if (state_out->speaker_cmvn_stats.NumRows() == 0) + state_out->speaker_cmvn_stats.Resize(2, dim + 1); + Vector feat(dim); + Vector feat_dbl(dim); + for (int32 t = 0; t <= cur_frame; t++) { + src_->GetFrame(t, &feat); + feat_dbl.CopyFromVec(feat); + state_out->speaker_cmvn_stats(0, dim) += 1.0; + state_out->speaker_cmvn_stats.Row(0).Range(0, dim).AddVec(1.0, feat_dbl); + state_out->speaker_cmvn_stats.Row(1).Range(0, dim).AddVec2(1.0, feat_dbl); + } + } + // Store any frozen state (the effect of the user possibly + // having called Freeze(). 
+ state_out->frozen_state = frozen_state_; +} + +void OnlineCmvn::SetState(const OnlineCmvnState &cmvn_state) { + KALDI_ASSERT(cached_stats_modulo_.empty() && + "You cannot call SetState() after processing data."); + orig_state_ = cmvn_state; + frozen_state_ = cmvn_state.frozen_state; +} + +int32 OnlineSpliceFrames::NumFramesReady() const { + int32 num_frames = src_->NumFramesReady(); + if (num_frames > 0 && src_->IsLastFrame(num_frames - 1)) + return num_frames; + else + return std::max(0, num_frames - right_context_); +} + +void OnlineSpliceFrames::GetFrame(int32 frame, VectorBase *feat) { + KALDI_ASSERT(left_context_ >= 0 && right_context_ >= 0); + KALDI_ASSERT(frame >= 0 && frame < NumFramesReady()); + int32 dim_in = src_->Dim(); + KALDI_ASSERT(feat->Dim() == dim_in * (1 + left_context_ + right_context_)); + int32 T = src_->NumFramesReady(); + for (int32 t2 = frame - left_context_; t2 <= frame + right_context_; t2++) { + int32 t2_limited = t2; + if (t2_limited < 0) t2_limited = 0; + if (t2_limited >= T) t2_limited = T - 1; + int32 n = t2 - (frame - left_context_); // 0 for left-most frame, + // increases to the right. + SubVector part(*feat, n * dim_in, dim_in); + src_->GetFrame(t2_limited, &part); + } +} + +OnlineTransform::OnlineTransform(const MatrixBase &transform, + OnlineFeatureInterface *src): + src_(src) { + int32 src_dim = src_->Dim(); + if (transform.NumCols() == src_dim) { // Linear transform + linear_term_ = transform; + offset_.Resize(transform.NumRows()); // Resize() will zero it. 
+ } else if (transform.NumCols() == src_dim + 1) { // Affine transform + linear_term_ = transform.Range(0, transform.NumRows(), 0, src_dim); + offset_.Resize(transform.NumRows()); + offset_.CopyColFromMat(transform, src_dim); + } else { + KALDI_ERR << "Dimension mismatch: source features have dimension " + << src_dim << " and LDA #cols is " << transform.NumCols(); + } +} + +void OnlineTransform::GetFrame(int32 frame, VectorBase *feat) { + Vector input_feat(linear_term_.NumCols()); + src_->GetFrame(frame, &input_feat); + feat->CopyFromVec(offset_); + feat->AddMatVec(1.0, linear_term_, kNoTrans, input_feat, 1.0); +} + +void OnlineTransform::GetFrames( + const std::vector &frames, MatrixBase *feats) { + KALDI_ASSERT(static_cast(frames.size()) == feats->NumRows()); + int32 num_frames = feats->NumRows(), + input_dim = linear_term_.NumCols(); + Matrix input_feats(num_frames, input_dim, kUndefined); + src_->GetFrames(frames, &input_feats); + feats->CopyRowsFromVec(offset_); + feats->AddMatMat(1.0, input_feats, kNoTrans, linear_term_, kTrans, 1.0); +} + + +int32 OnlineDeltaFeature::Dim() const { + int32 src_dim = src_->Dim(); + return src_dim * (1 + opts_.order); +} + +int32 OnlineDeltaFeature::NumFramesReady() const { + int32 num_frames = src_->NumFramesReady(), + context = opts_.order * opts_.window; + // "context" is the number of frames on the left or (more relevant + // here) right which we need in order to produce the output. + if (num_frames > 0 && src_->IsLastFrame(num_frames-1)) + return num_frames; + else + return std::max(0, num_frames - context); +} + +void OnlineDeltaFeature::GetFrame(int32 frame, + VectorBase *feat) { + KALDI_ASSERT(frame >= 0 && frame < NumFramesReady()); + KALDI_ASSERT(feat->Dim() == Dim()); + // We'll produce a temporary matrix containing the features we want to + // compute deltas on, but truncated to the necessary context. 
+ int32 context = opts_.order * opts_.window; + int32 left_frame = frame - context, + right_frame = frame + context, + src_frames_ready = src_->NumFramesReady(); + if (left_frame < 0) left_frame = 0; + if (right_frame >= src_frames_ready) + right_frame = src_frames_ready - 1; + KALDI_ASSERT(right_frame >= left_frame); + int32 temp_num_frames = right_frame + 1 - left_frame, + src_dim = src_->Dim(); + Matrix temp_src(temp_num_frames, src_dim); + for (int32 t = left_frame; t <= right_frame; t++) { + SubVector temp_row(temp_src, t - left_frame); + src_->GetFrame(t, &temp_row); + } + int32 temp_t = frame - left_frame; // temp_t is the offset of frame "frame" + // within temp_src + delta_features_.Process(temp_src, temp_t, feat); +} + + +OnlineDeltaFeature::OnlineDeltaFeature(const DeltaFeaturesOptions &opts, + OnlineFeatureInterface *src): + src_(src), opts_(opts), delta_features_(opts) { } + +void OnlineCacheFeature::GetFrame(int32 frame, VectorBase *feat) { + KALDI_ASSERT(frame >= 0); + if (static_cast(frame) < cache_.size() && cache_[frame] != NULL) { + feat->CopyFromVec(*(cache_[frame])); + } else { + if (static_cast(frame) >= cache_.size()) + cache_.resize(frame + 1, NULL); + int32 dim = this->Dim(); + cache_[frame] = new Vector(dim); + // The following call will crash if frame "frame" is not ready. + src_->GetFrame(frame, cache_[frame]); + feat->CopyFromVec(*(cache_[frame])); + } +} + +void OnlineCacheFeature::GetFrames( + const std::vector &frames, MatrixBase *feats) { + int32 num_frames = frames.size(); + // non_cached_frames will be the subset of 't' values in 'frames' which were + // not previously cached, which we therefore need to get from src_. + std::vector non_cached_frames; + // 'non_cached_indexes' stores the indexes 'i' into 'frames' corresponding to + // the corresponding frames in 'non_cached_frames'. 
+ std::vector non_cached_indexes; + non_cached_frames.reserve(frames.size()); + non_cached_indexes.reserve(frames.size()); + for (int32 i = 0; i < num_frames; i++) { + int32 t = frames[i]; + if (static_cast(t) < cache_.size() && cache_[t] != NULL) { + feats->Row(i).CopyFromVec(*(cache_[t])); + } else { + non_cached_frames.push_back(t); + non_cached_indexes.push_back(i); + } + } + if (non_cached_frames.empty()) + return; + int32 num_non_cached_frames = non_cached_frames.size(), + dim = this->Dim(); + Matrix non_cached_feats(num_non_cached_frames, dim, + kUndefined); + src_->GetFrames(non_cached_frames, &non_cached_feats); + for (int32 i = 0; i < num_non_cached_frames; i++) { + int32 t = non_cached_frames[i]; + if (static_cast(t) < cache_.size() && cache_[t] != NULL) { + // We can reach this point due to repeat indexes in 'non_cached_frames'. + feats->Row(non_cached_indexes[i]).CopyFromVec(*(cache_[t])); + } else { + SubVector this_feat(non_cached_feats, i); + feats->Row(non_cached_indexes[i]).CopyFromVec(this_feat); + if (static_cast(t) >= cache_.size()) + cache_.resize(t + 1, NULL); + cache_[t] = new Vector(this_feat); + } + } +} + + +void OnlineCacheFeature::ClearCache() { + for (size_t i = 0; i < cache_.size(); i++) + delete cache_[i]; + cache_.resize(0); +} + + +void OnlineAppendFeature::GetFrame(int32 frame, VectorBase *feat) { + KALDI_ASSERT(feat->Dim() == Dim()); + + SubVector feat1(*feat, 0, src1_->Dim()); + SubVector feat2(*feat, src1_->Dim(), src2_->Dim()); + src1_->GetFrame(frame, &feat1); + src2_->GetFrame(frame, &feat2); +}; + + +} // namespace kaldi diff --git a/speechx/speechx/kaldi/feat/online-feature.h b/speechx/speechx/kaldi/feat/online-feature.h new file mode 100644 index 00000000..f2ebe45b --- /dev/null +++ b/speechx/speechx/kaldi/feat/online-feature.h @@ -0,0 +1,632 @@ +// feat/online-feature.h + +// Copyright 2013 Johns Hopkins University (author: Daniel Povey) +// 2014 Yanqing Sun, Junjie Wang, +// Daniel Povey, Korbinian Riedhammer + +// See 
../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_FEAT_ONLINE_FEATURE_H_ +#define KALDI_FEAT_ONLINE_FEATURE_H_ + +#include +#include +#include + +#include "matrix/matrix-lib.h" +#include "util/common-utils.h" +#include "base/kaldi-error.h" +#include "feat/feature-functions.h" +#include "feat/feature-mfcc.h" +#include "feat/feature-plp.h" +#include "feat/feature-fbank.h" +#include "itf/online-feature-itf.h" + +namespace kaldi { +/// @addtogroup onlinefeat OnlineFeatureExtraction +/// @{ + + +/// This class serves as a storage for feature vectors with an option to limit +/// the memory usage by removing old elements. The deleted frames indices are +/// "remembered" so that regardless of the MAX_ITEMS setting, the user always +/// provides the indices as if no deletion was being performed. +/// This is useful when processing very long recordings which would otherwise +/// cause the memory to eventually blow up when the features are not being removed. +class RecyclingVector { +public: + /// By default it does not remove any elements. + RecyclingVector(int items_to_hold = -1); + + /// The ownership is being retained by this collection - do not delete the item. + Vector *At(int index) const; + + /// The ownership of the item is passed to this collection - do not delete the item. 
+ void PushBack(Vector *item); + + /// This method returns the size as if no "recycling" had happened, + /// i.e. equivalent to the number of times the PushBack method has been called. + int Size() const; + + ~RecyclingVector(); + +private: + std::deque*> items_; + int items_to_hold_; + int first_available_index_; +}; + + +/// This is a templated class for online feature extraction; +/// it's templated on a class like MfccComputer or PlpComputer +/// that does the basic feature extraction. +template +class OnlineGenericBaseFeature: public OnlineBaseFeature { + public: + // + // First, functions that are present in the interface: + // + virtual int32 Dim() const { return computer_.Dim(); } + + // Note: IsLastFrame() will only ever return true if you have called + // InputFinished() (and this frame is the last frame). + virtual bool IsLastFrame(int32 frame) const { + return input_finished_ && frame == NumFramesReady() - 1; + } + virtual BaseFloat FrameShiftInSeconds() const { + return computer_.GetFrameOptions().frame_shift_ms / 1000.0f; + } + + virtual int32 NumFramesReady() const { return features_.Size(); } + + virtual void GetFrame(int32 frame, VectorBase *feat); + + // Next, functions that are not in the interface. + + + // Constructor from options class + explicit OnlineGenericBaseFeature(const typename C::Options &opts); + + // This would be called from the application, when you get + // more wave data. Note: the sampling_rate is only provided so + // the code can assert that it matches the sampling rate + // expected in the options. + virtual void AcceptWaveform(BaseFloat sampling_rate, + const VectorBase &waveform); + + + // InputFinished() tells the class you won't be providing any + // more waveform. This will help flush out the last frame or two + // of features, in the case where snip-edges == false; it also + // affects the return value of IsLastFrame(). 
+ virtual void InputFinished(); + + private: + // This function computes any additional feature frames that it is possible to + // compute from 'waveform_remainder_', which at this point may contain more + // than just a remainder-sized quantity (because AcceptWaveform() appends to + // waveform_remainder_ before calling this function). It adds these feature + // frames to features_, and shifts off any now-unneeded samples of input from + // waveform_remainder_ while incrementing waveform_offset_ by the same amount. + void ComputeFeatures(); + + void MaybeCreateResampler(BaseFloat sampling_rate); + + C computer_; // class that does the MFCC or PLP or filterbank computation + + // resampler in cases when the input sampling frequency is not equal to + // the expected sampling rate + std::unique_ptr resampler_; + + FeatureWindowFunction window_function_; + + // features_ is the Mfcc or Plp or Fbank features that we have already computed. + + RecyclingVector features_; + + // True if the user has called "InputFinished()" + bool input_finished_; + + // The sampling frequency, extracted from the config. Should + // be identical to the waveform supplied. + BaseFloat sampling_frequency_; + + // waveform_offset_ is the number of samples of waveform that we have + // already discarded, i.e. that were prior to 'waveform_remainder_'. + int64 waveform_offset_; + + // waveform_remainder_ is a short piece of waveform that we may need to keep + // after extracting all the whole frames we can (whatever length of feature + // will be required for the next phase of computation). + Vector waveform_remainder_; +}; + +typedef OnlineGenericBaseFeature OnlineMfcc; +typedef OnlineGenericBaseFeature OnlinePlp; +typedef OnlineGenericBaseFeature OnlineFbank; + + +/// This class takes a Matrix and wraps it as an +/// OnlineFeatureInterface: this can be useful where some earlier stage of +/// feature processing has been done offline but you want to use part of the +/// online pipeline. 
+class OnlineMatrixFeature: public OnlineFeatureInterface { + public: + /// Caution: this class maintains the const reference from the constructor, so + /// don't let it go out of scope while this object exists. + explicit OnlineMatrixFeature(const MatrixBase &mat): mat_(mat) { } + + virtual int32 Dim() const { return mat_.NumCols(); } + + virtual BaseFloat FrameShiftInSeconds() const { + return 0.01f; + } + + virtual int32 NumFramesReady() const { return mat_.NumRows(); } + + virtual void GetFrame(int32 frame, VectorBase *feat) { + feat->CopyFromVec(mat_.Row(frame)); + } + + virtual bool IsLastFrame(int32 frame) const { + return (frame + 1 == mat_.NumRows()); + } + + + private: + const MatrixBase &mat_; +}; + + +// Note the similarity with SlidingWindowCmnOptions, but there +// are also differences. One which doesn't appear in the config +// itself, because it's a difference between the setups, is that +// in OnlineCmn, we carry over data from the previous utterance, +// or, if no previous utterance is available, from global stats, +// or, if previous utterances are available but the total amount +// of data is less than prev_frames, we pad with up to "global_frames" +// frames from the global stats. +struct OnlineCmvnOptions { + int32 cmn_window; + int32 speaker_frames; // must be <= cmn_window + int32 global_frames; // must be <= speaker_frames. + bool normalize_mean; // Must be true if normalize_variance==true. + bool normalize_variance; + + int32 modulus; // not configurable from command line, relates to how the + // class computes the cmvn internally. smaller->more + // time-efficient but less memory-efficient. Must be >= 1. + int32 ring_buffer_size; // not configurable from command line; size of ring + // buffer used for caching CMVN stats. Must be >= + // modulus. + std::string skip_dims; // Colon-separated list of dimensions to skip normalization + // of, e.g. 13:14:15. 
+ + OnlineCmvnOptions(): + cmn_window(600), + speaker_frames(600), + global_frames(200), + normalize_mean(true), + normalize_variance(false), + modulus(20), + ring_buffer_size(20), + skip_dims("") { } + + void Check() const { + KALDI_ASSERT(speaker_frames <= cmn_window && global_frames <= speaker_frames + && modulus > 0); + } + + void Register(ParseOptions *po) { + po->Register("cmn-window", &cmn_window, "Number of frames of sliding " + "context for cepstral mean normalization."); + po->Register("global-frames", &global_frames, "Number of frames of " + "global-average cepstral mean normalization stats to use for " + "first utterance of a speaker"); + po->Register("speaker-frames", &speaker_frames, "Number of frames of " + "previous utterance(s) from this speaker to use in cepstral " + "mean normalization"); + // we name the config string "norm-vars" for compatibility with + // ../featbin/apply-cmvn.cc + po->Register("norm-vars", &normalize_variance, "If true, do " + "cepstral variance normalization in addition to cepstral mean " + "normalization "); + po->Register("norm-means", &normalize_mean, "If true, do mean normalization " + "(note: you cannot normalize the variance but not the mean)"); + po->Register("skip-dims", &skip_dims, "Dimensions to skip normalization of " + "(colon-separated list of integers)");} +}; + + + +/** Struct OnlineCmvnState stores the state of CMVN adaptation between + utterances (but not the state of the computation within an utterance). It + stores the global CMVN stats and the stats of the current speaker (if we + have seen previous utterances for this speaker), and possibly will have a + member "frozen_state": if the user has called the function Freeze() of class + OnlineCmvn, to fix the CMVN so we can estimate fMLLR on top of the fixed + value of cmvn. If nonempty, "frozen_state" will reflect how we were + normalizing the mean and (if applicable) variance at the time when that + function was called. 
+*/ +struct OnlineCmvnState { + // The following is the total CMVN stats for this speaker (up till now), in + // the same format. + Matrix speaker_cmvn_stats; + + // The following is the global CMVN stats, in the usual + // format, of dimension 2 x (dim+1), as [ sum-stats count + // sum-squared-stats 0 ] + Matrix global_cmvn_stats; + + // If nonempty, contains CMVN stats representing the "frozen" state + // of CMVN that reflects how we were normalizing the data when the + // user called the Freeze() function in class OnlineCmvn. + Matrix frozen_state; + + OnlineCmvnState() { } + + explicit OnlineCmvnState(const Matrix &global_stats): + global_cmvn_stats(global_stats) { } + + // Copy constructor + OnlineCmvnState(const OnlineCmvnState &other); + + void Write(std::ostream &os, bool binary) const; + void Read(std::istream &is, bool binary); + + // Use the default assignment operator. +}; + +/** + This class does an online version of the cepstral mean and [optionally] + variance, but note that this is not equivalent to the offline version. This + is necessarily so, as the offline computation involves looking into the + future. If you plan to use features normalized with this type of CMVN then + you need to train in a `matched' way, i.e. with the same type of features. + We normally only do so in the "online" GMM-based decoding, e.g. in + online2bin/online2-wav-gmm-latgen-faster.cc; see also the script + steps/online/prepare_online_decoding.sh and steps/online/decode.sh. + + In the steady state (in the middle of a long utterance), this class + accumulates CMVN statistics from the previous "cmn_window" frames (default 600 + frames, or 6 seconds), and uses these to normalize the mean and possibly + variance of the current frame. + + The config variables "speaker_frames" and "global_frames" relate to what + happens at the beginning of the utterance when we have seen fewer than + "cmn_window" frames of context, and so might not have very good stats to + normalize with. 
Basically, we first augment any existing stats with up + to "speaker_frames" frames of stats from previous utterances of the current + speaker, and if this doesn't take us up to the required "cmn_window" frame + count, we further augment with up to "global_frames" frames of global + stats. The global stats are CMVN stats accumulated from training or testing + data, that give us a reasonable source of mean and variance for "typical" + data. + */ +class OnlineCmvn: public OnlineFeatureInterface { + public: + + // + // First, functions that are present in the interface: + // + virtual int32 Dim() const { return src_->Dim(); } + + virtual bool IsLastFrame(int32 frame) const { + return src_->IsLastFrame(frame); + } + virtual BaseFloat FrameShiftInSeconds() const { + return src_->FrameShiftInSeconds(); + } + + // The online cmvn does not introduce any additional latency. + virtual int32 NumFramesReady() const { return src_->NumFramesReady(); } + + virtual void GetFrame(int32 frame, VectorBase *feat); + + // + // Next, functions that are not in the interface. + // + + /// Initializer that sets the cmvn state. If you don't have previous + /// utterances from the same speaker you are supposed to initialize the CMVN + /// state from some global CMVN stats, which you can get from summing all cmvn + /// stats you have in your training data using "sum-matrix". This just gives + /// it a reasonable starting point at the start of the file. + /// If you do have previous utterances from the same speaker or at least a + /// similar environment, you are supposed to initialize it by calling GetState + /// from the previous utterance + OnlineCmvn(const OnlineCmvnOptions &opts, + const OnlineCmvnState &cmvn_state, + OnlineFeatureInterface *src); + + /// Initializer that does not set the cmvn state: + /// after calling this, you should call SetState(). 
+ OnlineCmvn(const OnlineCmvnOptions &opts, + OnlineFeatureInterface *src); + + // Outputs any state information from this utterance to "cmvn_state". + // The value of "cmvn_state" before the call does not matter: the output + // depends on the value of OnlineCmvnState the class was initialized + // with, the input feature values up to cur_frame, and the effects + // of the user possibly having called Freeze(). + // If cur_frame is -1, it will just output the unmodified original + // state that was supplied to this object. + void GetState(int32 cur_frame, + OnlineCmvnState *cmvn_state); + + // This function can be used to modify the state of the CMVN computation + // from outside, but must only be called before you have processed any data + // (otherwise it will crash). This "state" is really just the information + // that is propagated between utterances, not the state of the computation + // inside an utterance. + void SetState(const OnlineCmvnState &cmvn_state); + + // From this point it will freeze the CMN to what it would have been if + // measured at frame "cur_frame", and it will stop it from changing + // further. This also applies retroactively for this utterance, so if you + // call GetFrame() on previous frames, it will use the CMVN stats + // from cur_frame; and it applies in the future too if you then + // call OutputState() and use this state to initialize the next + // utterance's CMVN object. + void Freeze(int32 cur_frame); + + virtual ~OnlineCmvn(); + private: + + /// Smooth the CMVN stats "stats" (which are stored in the normal format as a + /// 2 x (dim+1) matrix), by possibly adding some stats from "global_stats" + /// and/or "speaker_stats", controlled by the config. The best way to + /// understand the smoothing rule we use is just to look at the code. 
+ static void SmoothOnlineCmvnStats(const MatrixBase &speaker_stats, + const MatrixBase &global_stats, + const OnlineCmvnOptions &opts, + MatrixBase *stats); + + /// Get the most recent cached frame of CMVN stats. [If no frames + /// were cached, sets up empty stats for frame zero and returns that]. + void GetMostRecentCachedFrame(int32 frame, + int32 *cached_frame, + MatrixBase *stats); + + /// Cache this frame of stats. + void CacheFrame(int32 frame, const MatrixBase &stats); + + /// Initialize ring buffer for caching stats. + inline void InitRingBufferIfNeeded(); + + /// Computes the raw CMVN stats for this frame, making use of (and updating if + /// necessary) the cached statistics in raw_stats_. This means the (x, + /// x^2, count) stats for the last up to opts_.cmn_window frames. + void ComputeStatsForFrame(int32 frame, + MatrixBase *stats); + + + OnlineCmvnOptions opts_; + std::vector skip_dims_; // Skip CMVN for these dimensions. Derived from opts_. + OnlineCmvnState orig_state_; // reflects the state before we saw this + // utterance. + Matrix frozen_state_; // If the user called Freeze(), this variable + // will reflect the CMVN state that we froze + // at. + + // The variable below reflects the raw (count, x, x^2) statistics of the + // input, computed every opts_.modulus frames. raw_stats_[n / opts_.modulus] + // contains the (count, x, x^2) statistics for the frames from + // std::max(0, n - opts_.cmn_window) through n. + std::vector*> cached_stats_modulo_; + // the variable below is a ring-buffer of cached stats. the int32 is the + // frame index. + std::vector > > cached_stats_ring_; + + // Some temporary variables used inside functions of this class, which + // put here to avoid reallocation. 
+ Matrix temp_stats_; + Vector temp_feats_; + Vector temp_feats_dbl_; + + OnlineFeatureInterface *src_; // Not owned here +}; + + +struct OnlineSpliceOptions { + int32 left_context; + int32 right_context; + OnlineSpliceOptions(): left_context(4), right_context(4) { } + void Register(ParseOptions *po) { + po->Register("left-context", &left_context, "Left-context for frame " + "splicing prior to LDA"); + po->Register("right-context", &right_context, "Right-context for frame " + "splicing prior to LDA"); + } +}; + +class OnlineSpliceFrames: public OnlineFeatureInterface { + public: + // + // First, functions that are present in the interface: + // + virtual int32 Dim() const { + return src_->Dim() * (1 + left_context_ + right_context_); + } + + virtual bool IsLastFrame(int32 frame) const { + return src_->IsLastFrame(frame); + } + virtual BaseFloat FrameShiftInSeconds() const { + return src_->FrameShiftInSeconds(); + } + + virtual int32 NumFramesReady() const; + + virtual void GetFrame(int32 frame, VectorBase *feat); + + // + // Next, functions that are not in the interface. + // + OnlineSpliceFrames(const OnlineSpliceOptions &opts, + OnlineFeatureInterface *src): + left_context_(opts.left_context), right_context_(opts.right_context), + src_(src) { } + + private: + int32 left_context_; + int32 right_context_; + OnlineFeatureInterface *src_; // Not owned here +}; + +/// This online-feature class implements any affine or linear transform. 
+class OnlineTransform: public OnlineFeatureInterface { + public: + // + // First, functions that are present in the interface: + // + virtual int32 Dim() const { return offset_.Dim(); } + + virtual bool IsLastFrame(int32 frame) const { + return src_->IsLastFrame(frame); + } + virtual BaseFloat FrameShiftInSeconds() const { + return src_->FrameShiftInSeconds(); + } + + virtual int32 NumFramesReady() const { return src_->NumFramesReady(); } + + virtual void GetFrame(int32 frame, VectorBase *feat); + + virtual void GetFrames(const std::vector &frames, + MatrixBase *feats); + + // + // Next, functions that are not in the interface. + // + + /// The transform can be a linear transform, or an affine transform + /// where the last column is the offset. + OnlineTransform(const MatrixBase &transform, + OnlineFeatureInterface *src); + + + private: + OnlineFeatureInterface *src_; // Not owned here + Matrix linear_term_; + Vector offset_; +}; + +class OnlineDeltaFeature: public OnlineFeatureInterface { + public: + // + // First, functions that are present in the interface: + // + virtual int32 Dim() const; + + virtual bool IsLastFrame(int32 frame) const { + return src_->IsLastFrame(frame); + } + virtual BaseFloat FrameShiftInSeconds() const { + return src_->FrameShiftInSeconds(); + } + + virtual int32 NumFramesReady() const; + + virtual void GetFrame(int32 frame, VectorBase *feat); + + // + // Next, functions that are not in the interface. + // + OnlineDeltaFeature(const DeltaFeaturesOptions &opts, + OnlineFeatureInterface *src); + + private: + OnlineFeatureInterface *src_; // Not owned here + DeltaFeaturesOptions opts_; + DeltaFeatures delta_features_; // This class contains just a few + // coefficients. +}; + + +/// This feature type can be used to cache its input, to avoid +/// repetition of computation in a multi-pass decoding context. 
+class OnlineCacheFeature: public OnlineFeatureInterface { + public: + virtual int32 Dim() const { return src_->Dim(); } + + virtual bool IsLastFrame(int32 frame) const { + return src_->IsLastFrame(frame); + } + virtual BaseFloat FrameShiftInSeconds() const { + return src_->FrameShiftInSeconds(); + } + + virtual int32 NumFramesReady() const { return src_->NumFramesReady(); } + + virtual void GetFrame(int32 frame, VectorBase *feat); + + virtual void GetFrames(const std::vector &frames, + MatrixBase *feats); + + virtual ~OnlineCacheFeature() { ClearCache(); } + + // Things that are not in the shared interface: + + void ClearCache(); // this should be called if you change the underlying + // features in some way. + + explicit OnlineCacheFeature(OnlineFeatureInterface *src): src_(src) { } + private: + + OnlineFeatureInterface *src_; // Not owned here + std::vector* > cache_; +}; + + + + +/// This online-feature class implements combination of two feature +/// streams (such as pitch, plp) into one stream. 
+class OnlineAppendFeature: public OnlineFeatureInterface { + public: + virtual int32 Dim() const { return src1_->Dim() + src2_->Dim(); } + + virtual bool IsLastFrame(int32 frame) const { + return (src1_->IsLastFrame(frame) || src2_->IsLastFrame(frame)); + } + // Hopefully sources have the same rate + virtual BaseFloat FrameShiftInSeconds() const { + return src1_->FrameShiftInSeconds(); + } + + virtual int32 NumFramesReady() const { + return std::min(src1_->NumFramesReady(), src2_->NumFramesReady()); + } + + virtual void GetFrame(int32 frame, VectorBase *feat); + + virtual ~OnlineAppendFeature() { } + + OnlineAppendFeature(OnlineFeatureInterface *src1, + OnlineFeatureInterface *src2): src1_(src1), src2_(src2) { } + private: + + OnlineFeatureInterface *src1_; + OnlineFeatureInterface *src2_; +}; + +/// @} End of "addtogroup onlinefeat" +} // namespace kaldi + +#endif // KALDI_FEAT_ONLINE_FEATURE_H_ diff --git a/speechx/speechx/kaldi/feat/pitch-functions.cc b/speechx/speechx/kaldi/feat/pitch-functions.cc new file mode 100644 index 00000000..430e9bdb --- /dev/null +++ b/speechx/speechx/kaldi/feat/pitch-functions.cc @@ -0,0 +1,1667 @@ +// feat/pitch-functions.cc + +// Copyright 2013 Pegah Ghahremani +// 2014 IMSL, PKU-HKUST (author: Wei Shi) +// 2014 Yanqing Sun, Junjie Wang, +// Daniel Povey, Korbinian Riedhammer +// Xin Lei + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. 
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "feat/feature-functions.h" +#include "feat/mel-computations.h" +#include "feat/online-feature.h" +#include "feat/pitch-functions.h" +#include "feat/resample.h" +#include "matrix/matrix-functions.h" + +namespace kaldi { + +/** + This function processes the NCCF n to a POV feature f by applying the formula + f = (1.0001 - n)^0.15 - 1.0 + This is a nonlinear function designed to make the output reasonably Gaussian + distributed. Before doing this, the NCCF distribution is in the range [-1, + 1] but has a strong peak just before 1.0, which this function smooths out. +*/ + +BaseFloat NccfToPovFeature(BaseFloat n) { + if (n > 1.0) { + n = 1.0; + } else if (n < -1.0) { + n = -1.0; + } + BaseFloat f = pow((1.0001 - n), 0.15) - 1.0; + KALDI_ASSERT(f - f == 0); // check for NaN,inf. + return f; +} + +/** + This function processes the NCCF n to a reasonably accurate probability + of voicing p by applying the formula: + + n' = fabs(n) + r = -5.2 + 5.4 * exp(7.5 * (n' - 1.0)) + + 4.8 * n' - 2.0 * exp(-10.0 * n') + 4.2 * exp(20.0 * (n' - 1.0)); + p = 1.0 / (1 + exp(-1.0 * r)); + + How did we get this formula? We plotted the empirical log-prob-ratio of voicing + r = log( p[voiced] / p[not-voiced] ) + [on the Keele database where voicing is marked], as a function of the NCCF at + the delay picked by our algorithm. This was done on intervals of the NCCF, so + we had enough statistics to get that ratio. The NCCF covers [-1, 1]; almost + all of the probability mass is on [0, 1] but the empirical POV seems fairly + symmetric with a minimum near zero, so we chose to make it a function of n' = fabs(n). 
+ + Then we manually tuned a function (the one you see above) that approximated + the log-prob-ratio of voicing fairly well as a function of the absolute-value + NCCF n'; however, wasn't a very exact match since we were also trying to make + the transformed NCCF fairly Gaussian distributed, with a view to using it as + a feature-- an idea we later abandoned after a simpler formula worked better. + */ +BaseFloat NccfToPov(BaseFloat n) { + BaseFloat ndash = fabs(n); + if (ndash > 1.0) ndash = 1.0; // just in case it was slightly outside [-1, 1] + + BaseFloat r = -5.2 + 5.4 * Exp(7.5 * (ndash - 1.0)) + 4.8 * ndash - + 2.0 * Exp(-10.0 * ndash) + 4.2 * Exp(20.0 * (ndash - 1.0)); + // r is the approximate log-prob-ratio of voicing, log(p/(1-p)). + BaseFloat p = 1.0 / (1 + Exp(-1.0 * r)); + KALDI_ASSERT(p - p == 0); // Check for NaN/inf + return p; +} + +/** + This function computes some dot products that are required + while computing the NCCF. + For each integer lag from start to end-1, this function + outputs to (*inner_prod)(lag - start), the dot-product + of a window starting at 0 with a window starting at + lag. All windows are of length nccf_window_size. It + outputs to (*norm_prod)(lag - start), e1 * e2, where + e1 is the dot-product of the un-shifted window with itself, + and d2 is the dot-product of the window shifted by "lag" + with itself. + */ +void ComputeCorrelation(const VectorBase &wave, + int32 first_lag, int32 last_lag, + int32 nccf_window_size, + VectorBase *inner_prod, + VectorBase *norm_prod) { + Vector zero_mean_wave(wave); + // TODO: possibly fix this, the mean normalization is done in a strange way. 
+ SubVector wave_part(wave, 0, nccf_window_size); + // subtract mean-frame from wave + zero_mean_wave.Add(-wave_part.Sum() / nccf_window_size); + BaseFloat e1, e2, sum; + SubVector sub_vec1(zero_mean_wave, 0, nccf_window_size); + e1 = VecVec(sub_vec1, sub_vec1); + for (int32 lag = first_lag; lag <= last_lag; lag++) { + SubVector sub_vec2(zero_mean_wave, lag, nccf_window_size); + e2 = VecVec(sub_vec2, sub_vec2); + sum = VecVec(sub_vec1, sub_vec2); + (*inner_prod)(lag - first_lag) = sum; + (*norm_prod)(lag - first_lag) = e1 * e2; + } +} + +/** + Computes the NCCF as a fraction of the numerator term (a dot product between + two vectors) and a denominator term which equals sqrt(e1*e2 + nccf_ballast) + where e1 and e2 are both dot-products of bits of the wave with themselves, + and e1*e2 is supplied as "norm_prod". These quantities are computed by + "ComputeCorrelation". +*/ +void ComputeNccf(const VectorBase &inner_prod, + const VectorBase &norm_prod, + BaseFloat nccf_ballast, + VectorBase *nccf_vec) { + KALDI_ASSERT(inner_prod.Dim() == norm_prod.Dim() && + inner_prod.Dim() == nccf_vec->Dim()); + for (int32 lag = 0; lag < inner_prod.Dim(); lag++) { + BaseFloat numerator = inner_prod(lag), + denominator = pow(norm_prod(lag) + nccf_ballast, 0.5), + nccf; + if (denominator != 0.0) { + nccf = numerator / denominator; + } else { + KALDI_ASSERT(numerator == 0.0); + nccf = 0.0; + } + KALDI_ASSERT(nccf < 1.01 && nccf > -1.01); + (*nccf_vec)(lag) = nccf; + } +} + +/** + This function selects the lags at which we measure the NCCF: we need + to select lags from 1/max_f0 to 1/min_f0, in a geometric progression + with ratio 1 + d. 
+ */ +void SelectLags(const PitchExtractionOptions &opts, + Vector *lags) { + // choose lags relative to acceptable pitch tolerance + BaseFloat min_lag = 1.0 / opts.max_f0, max_lag = 1.0 / opts.min_f0; + + std::vector tmp_lags; + for (BaseFloat lag = min_lag; lag <= max_lag; lag *= 1.0 + opts.delta_pitch) + tmp_lags.push_back(lag); + lags->Resize(tmp_lags.size()); + std::copy(tmp_lags.begin(), tmp_lags.end(), lags->Data()); +} + + +/** + This function computes the local-cost for the Viterbi computation, + see eq. (5) in the paper. + @param opts The options as provided by the user + @param nccf_pitch The nccf as computed for the pitch computation (with ballast). + @param lags The log-spaced lags at which nccf_pitch is sampled. + @param local_cost We output the local-cost to here. +*/ +void ComputeLocalCost(const VectorBase &nccf_pitch, + const VectorBase &lags, + const PitchExtractionOptions &opts, + VectorBase *local_cost) { + // from the paper, eq. 5, local_cost = 1 - Phi(t,i)(1 - soft_min_f0 L_i) + // nccf is the nccf on this frame measured at the lags in "lags". + KALDI_ASSERT(nccf_pitch.Dim() == local_cost->Dim() && + nccf_pitch.Dim() == lags.Dim()); + local_cost->Set(1.0); + // add the term -Phi(t,i): + local_cost->AddVec(-1.0, nccf_pitch); + // add the term soft_min_f0 Phi(t,i) L_i + local_cost->AddVecVec(opts.soft_min_f0, lags, nccf_pitch, 1.0); +} + + + +// class PitchFrameInfo is used inside class OnlinePitchFeatureImpl. +// It stores the information we need to keep around for a single frame +// of the pitch computation. +class PitchFrameInfo { + public: + /// This function resizes the arrays for this object and updates the reference + /// counts for the previous object (by decrementing those reference counts + /// when we destroy a StateInfo object). A StateInfo object is considered to + /// be destroyed when we delete it, not when its reference counts goes to + /// zero. 
+ void Cleanup(PitchFrameInfo *prev_frame); + + /// This function may be called for the last (most recent) PitchFrameInfo + /// object with the best state (obtained from the externally held + /// forward-costs). It traces back as far as needed to set the + /// cur_best_state_, and as it's going it sets the lag-index and pov_nccf in + /// pitch_pov_iter, which when it's called is an iterator to where to put the + /// info for the final state; the iterator will be decremented inside this + /// function. + void SetBestState(int32 best_state, + std::vector > &lag_nccf); + + /// This function may be called on the last (most recent) PitchFrameInfo + /// object; it computes how many frames of latency there is because the + /// traceback has not yet settled on a single value for frames in the past. + /// It actually returns the minimum of max_latency and the actual latency, + /// which is an optimization because we won't care about latency past + /// a user-specified maximum latency. + int32 ComputeLatency(int32 max_latency); + + /// This function updates + bool UpdatePreviousBestState(PitchFrameInfo *prev_frame); + + /// This constructor is used for frame -1; it sets the costs to be all zeros + /// the pov_nccf's to zero and the backpointers to -1. + explicit PitchFrameInfo(int32 num_states); + + /// This constructor is used for subsequent frames (not -1). + PitchFrameInfo(PitchFrameInfo *prev); + + /// Record the nccf_pov value. + /// @param nccf_pov The nccf as computed for the POV computation (without ballast). + void SetNccfPov(const VectorBase &nccf_pov); + + /// This constructor is used for frames apart from frame -1; the bulk of + /// the Viterbi computation takes place inside this constructor. + /// @param opts The options as provided by the user + /// @param nccf_pitch The nccf as computed for the pitch computation + /// (with ballast). + /// @param nccf_pov The nccf as computed for the POV computation + /// (without ballast). 
+ /// @param lags The log-spaced lags at which nccf_pitch and + /// nccf_pov are sampled. + /// @param prev_frame_forward_cost The forward-cost vector for the + /// previous frame. + /// @param index_info A pointer to a temporary vector used by this function + /// @param this_forward_cost The forward-cost vector for this frame + /// (to be computed). + void ComputeBacktraces(const PitchExtractionOptions &opts, + const VectorBase &nccf_pitch, + const VectorBase &lags, + const VectorBase &prev_forward_cost, + std::vector > *index_info, + VectorBase *this_forward_cost); + private: + // struct StateInfo is the information we keep for a single one of the + // log-spaced lags, for a single frame. This is a state in the Viterbi + // computation. + struct StateInfo { + /// The state index on the previous frame that is the best preceding state + /// for this state. + int32 backpointer; + /// the version of the NCCF we keep for the POV computation (without the + /// ballast term). + BaseFloat pov_nccf; + StateInfo(): backpointer(0), pov_nccf(0.0) { } + }; + std::vector state_info_; + /// the state index of the first entry in "state_info"; this will initially be + /// zero, but after cleanup might be nonzero. + int32 state_offset_; + + /// The current best state in the backtrace from the end. + int32 cur_best_state_; + + /// The structure for the previous frame. + PitchFrameInfo *prev_info_; +}; + + +// This constructor is used for frame -1; it sets the costs to be all zeros +// the pov_nccf's to zero and the backpointers to -1. +PitchFrameInfo::PitchFrameInfo(int32 num_states) + :state_info_(num_states), state_offset_(0), + cur_best_state_(-1), prev_info_(NULL) { } + + +bool pitch_use_naive_search = false; // This is used in unit-tests. 
+ + +PitchFrameInfo::PitchFrameInfo(PitchFrameInfo *prev_info): + state_info_(prev_info->state_info_.size()), state_offset_(0), + cur_best_state_(-1), prev_info_(prev_info) { } + +void PitchFrameInfo::SetNccfPov(const VectorBase &nccf_pov) { + int32 num_states = nccf_pov.Dim(); + KALDI_ASSERT(num_states == state_info_.size()); + for (int32 i = 0; i < num_states; i++) + state_info_[i].pov_nccf = nccf_pov(i); +} + +void PitchFrameInfo::ComputeBacktraces( + const PitchExtractionOptions &opts, + const VectorBase &nccf_pitch, + const VectorBase &lags, + const VectorBase &prev_forward_cost_vec, + std::vector > *index_info, + VectorBase *this_forward_cost_vec) { + int32 num_states = nccf_pitch.Dim(); + + Vector local_cost(num_states, kUndefined); + ComputeLocalCost(nccf_pitch, lags, opts, &local_cost); + + const BaseFloat delta_pitch_sq = pow(Log(1.0 + opts.delta_pitch), 2.0), + inter_frame_factor = delta_pitch_sq * opts.penalty_factor; + + // index local_cost, prev_forward_cost and this_forward_cost using raw pointer + // indexing not operator (), since this is the very inner loop and a lot of + // time is taken here. + const BaseFloat *prev_forward_cost = prev_forward_cost_vec.Data(); + BaseFloat *this_forward_cost = this_forward_cost_vec->Data(); + + if (index_info->empty()) + index_info->resize(num_states); + + // make it a reference for more concise indexing. + std::vector > &bounds = *index_info; + + /* bounds[i].first will be a lower bound on the backpointer for state i, + bounds[i].second will be an upper bound on it. We progressively tighten + these bounds till we know the backpointers exactly. + */ + + if (pitch_use_naive_search) { + // This branch is only taken in unit-testing code. 
+ for (int32 i = 0; i < num_states; i++) { + BaseFloat best_cost = std::numeric_limits::infinity(); + int32 best_j = -1; + for (int32 j = 0; j < num_states; j++) { + BaseFloat this_cost = (j - i) * (j - i) * inter_frame_factor + + prev_forward_cost[j]; + if (this_cost < best_cost) { + best_cost = this_cost; + best_j = j; + } + } + this_forward_cost[i] = best_cost; + state_info_[i].backpointer = best_j; + } + } else { + int32 last_backpointer = 0; + for (int32 i = 0; i < num_states; i++) { + int32 start_j = last_backpointer; + BaseFloat best_cost = (start_j - i) * (start_j - i) * inter_frame_factor + + prev_forward_cost[start_j]; + int32 best_j = start_j; + + for (int32 j = start_j + 1; j < num_states; j++) { + BaseFloat this_cost = (j - i) * (j - i) * inter_frame_factor + + prev_forward_cost[j]; + if (this_cost < best_cost) { + best_cost = this_cost; + best_j = j; + } else { // as soon as the costs stop improving, we stop searching. + break; // this is a loose lower bound we're getting. + } + } + state_info_[i].backpointer = best_j; + this_forward_cost[i] = best_cost; + bounds[i].first = best_j; // this is now a lower bound on the + // backpointer. + bounds[i].second = num_states - 1; // we have no meaningful upper bound + // yet. + last_backpointer = best_j; + } + + // We iterate, progressively refining the upper and lower bounds until they + // meet and we know that the resulting backtraces are optimal. Each + // iteration takes time linear in num_states. We won't normally iterate as + // far as num_states; normally we only do two iterations; when printing out + // the number of iterations, it's rarely more than that (once I saw seven + // iterations). Anyway, this part of the computation does not dominate. 
+ for (int32 iter = 0; iter < num_states; iter++) { + bool changed = false; + if (iter % 2 == 0) { // go backwards through the states + last_backpointer = num_states - 1; + for (int32 i = num_states - 1; i >= 0; i--) { + int32 lower_bound = bounds[i].first, + upper_bound = std::min(last_backpointer, bounds[i].second); + if (upper_bound == lower_bound) { + last_backpointer = lower_bound; + continue; + } + BaseFloat best_cost = this_forward_cost[i]; + int32 best_j = state_info_[i].backpointer, initial_best_j = best_j; + + if (best_j == upper_bound) { + // if best_j already equals upper bound, don't bother tightening the + // upper bound, we'll tighten the lower bound when the time comes. + last_backpointer = best_j; + continue; + } + // Below, we have j > lower_bound + 1 because we know we've already + // evaluated lower_bound and lower_bound + 1 [via knowledge of + // this algorithm.] + for (int32 j = upper_bound; j > lower_bound + 1; j--) { + BaseFloat this_cost = (j - i) * (j - i) * inter_frame_factor + + prev_forward_cost[j]; + if (this_cost < best_cost) { + best_cost = this_cost; + best_j = j; + } else { // as soon as the costs stop improving, we stop searching, + // unless the best j is still lower than j, in which case + // we obviously need to keep moving. + if (best_j > j) + break; // this is a loose lower bound we're getting. + } + } + // our "best_j" is now an upper bound on the backpointer. + bounds[i].second = best_j; + if (best_j != initial_best_j) { + this_forward_cost[i] = best_cost; + state_info_[i].backpointer = best_j; + changed = true; + } + last_backpointer = best_j; + } + } else { // go forwards through the states. 
+ last_backpointer = 0; + for (int32 i = 0; i < num_states; i++) { + int32 lower_bound = std::max(last_backpointer, bounds[i].first), + upper_bound = bounds[i].second; + if (upper_bound == lower_bound) { + last_backpointer = lower_bound; + continue; + } + BaseFloat best_cost = this_forward_cost[i]; + int32 best_j = state_info_[i].backpointer, initial_best_j = best_j; + + if (best_j == lower_bound) { + // if best_j already equals lower bound, we don't bother tightening + // the lower bound, we'll tighten the upper bound when the time + // comes. + last_backpointer = best_j; + continue; + } + // Below, we have j < upper_bound because we know we've already + // evaluated that point. + for (int32 j = lower_bound; j < upper_bound - 1; j++) { + BaseFloat this_cost = (j - i) * (j - i) * inter_frame_factor + + prev_forward_cost[j]; + if (this_cost < best_cost) { + best_cost = this_cost; + best_j = j; + } else { // as soon as the costs stop improving, we stop searching, + // unless the best j is still higher than j, in which case + // we obviously need to keep moving. + if (best_j < j) + break; // this is a loose lower bound we're getting. + } + } + // our "best_j" is now a lower bound on the backpointer. + bounds[i].first = best_j; + if (best_j != initial_best_j) { + this_forward_cost[i] = best_cost; + state_info_[i].backpointer = best_j; + changed = true; + } + last_backpointer = best_j; + } + } + if (!changed) + break; + } + } + // The next statement is needed due to RecomputeBacktraces: we have to + // invalidate the previously computed best-state info. + cur_best_state_ = -1; + this_forward_cost_vec->AddVec(1.0, local_cost); +} + +void PitchFrameInfo::SetBestState( + int32 best_state, + std::vector > &lag_nccf) { + + // This function would naturally be recursive, but we have coded this to avoid + // recursion, which would otherwise eat up the stack. Think of it as a static + // member function, except we do use "this" right at the beginning. 
+ + std::vector >::reverse_iterator iter = lag_nccf.rbegin(); + + PitchFrameInfo *this_info = this; // it will change in the loop. + while (this_info != NULL) { + PitchFrameInfo *prev_info = this_info->prev_info_; + if (best_state == this_info->cur_best_state_) + return; // no change + if (prev_info != NULL) // don't write anything for frame -1. + iter->first = best_state; + size_t state_info_index = best_state - this_info->state_offset_; + KALDI_ASSERT(state_info_index < this_info->state_info_.size()); + this_info->cur_best_state_ = best_state; + best_state = this_info->state_info_[state_info_index].backpointer; + if (prev_info != NULL) // don't write anything for frame -1. + iter->second = this_info->state_info_[state_info_index].pov_nccf; + this_info = prev_info; + if (this_info != NULL) ++iter; + } +} + +int32 PitchFrameInfo::ComputeLatency(int32 max_latency) { + if (max_latency <= 0) return 0; + + int32 latency = 0; + + // This function would naturally be recursive, but we have coded this to avoid + // recursion, which would otherwise eat up the stack. Think of it as a static + // member function, except we do use "this" right at the beginning. + // This function is called only on the most recent PitchFrameInfo object. + int32 num_states = state_info_.size(); + int32 min_living_state = 0, max_living_state = num_states - 1; + PitchFrameInfo *this_info = this; // it will change in the loop. 
+ + + for (; this_info != NULL && latency < max_latency;) { + int32 offset = this_info->state_offset_; + KALDI_ASSERT(min_living_state >= offset && + max_living_state - offset < this_info->state_info_.size()); + min_living_state = + this_info->state_info_[min_living_state - offset].backpointer; + max_living_state = + this_info->state_info_[max_living_state - offset].backpointer; + if (min_living_state == max_living_state) { + return latency; + } + this_info = this_info->prev_info_; + if (this_info != NULL) // avoid incrementing latency for frame -1, + latency++; // as it's not a real frame. + } + return latency; +} + +void PitchFrameInfo::Cleanup(PitchFrameInfo *prev_frame) { + KALDI_ERR << "Cleanup not implemented."; +} + + +// struct NccfInfo is used to cache certain quantities that we need for online +// operation, for the first "recompute_frame" frames of the file (e.g. 300); +// after that many frames, or after the user calls InputFinished(), we redo the +// initial backtraces, as we'll then have a better estimate of the average signal +// energy. +struct NccfInfo { + + Vector nccf_pitch_resampled; // resampled nccf_pitch + BaseFloat avg_norm_prod; // average value of e1 * e2. + BaseFloat mean_square_energy; // mean_square energy we used when computing the + // original ballast term for + // "nccf_pitch_resampled". + + NccfInfo(BaseFloat avg_norm_prod, + BaseFloat mean_square_energy): + avg_norm_prod(avg_norm_prod), + mean_square_energy(mean_square_energy) { } +}; + + + +// We could inherit from OnlineBaseFeature as we have the same interface, +// but this will unnecessary force a lot of our functions to be virtual. 
+class OnlinePitchFeatureImpl { + public: + explicit OnlinePitchFeatureImpl(const PitchExtractionOptions &opts); + + int32 Dim() const { return 2; } + + BaseFloat FrameShiftInSeconds() const; + + int32 NumFramesReady() const; + + bool IsLastFrame(int32 frame) const; + + void GetFrame(int32 frame, VectorBase *feat); + + void AcceptWaveform(BaseFloat sampling_rate, + const VectorBase &waveform); + + void InputFinished(); + + ~OnlinePitchFeatureImpl(); + + + // Copy-constructor, can be used to obtain a new copy of this object, + // any state from this utterance. + OnlinePitchFeatureImpl(const OnlinePitchFeatureImpl &other); + + private: + + /// This function works out from the signal how many frames are currently + /// available to process (this is called from inside AcceptWaveform()). + /// Note: the number of frames differs slightly from the number the + /// old pitch code gave. + /// Note: the number this returns depends on whether input_finished_ == true; + /// if it is, it will "force out" a final frame or two. + int32 NumFramesAvailable(int64 num_downsampled_samples, bool snip_edges) const; + + /// This function extracts from the signal the samples numbered from + /// "sample_index" (numbered in the full downsampled signal, not just this + /// part), and of length equal to window->Dim(). It uses the data members + /// downsampled_samples_discarded_ and downsampled_signal_remainder_, as well + /// as the more recent part of the downsampled wave "downsampled_wave_part" + /// which is provided. + /// + /// @param downsampled_wave_part One chunk of the downsampled wave, + /// starting from sample-index downsampled_samples_discarded_. + /// @param sample_index The desired starting sample index (measured from + /// the start of the whole signal, not just this part). + /// @param window The part of the signal is output to here. 
+ void ExtractFrame(const VectorBase &downsampled_wave_part, + int64 frame_index, + VectorBase *window); + + + /// This function is called after we reach frame "recompute_frame", or when + /// InputFinished() is called, whichever comes sooner. It recomputes the + /// backtraces for frames zero through recompute_frame, if needed because the + /// average energy of the signal has changed, affecting the nccf ballast term. + /// It works out the average signal energy from + /// downsampled_samples_processed_, signal_sum_ and signal_sumsq_ (which, if + /// you see the calling code, might include more frames than just + /// "recompute_frame", it might include up to the end of the current chunk). + void RecomputeBacktraces(); + + + /// This function updates downsampled_signal_remainder_, + /// downsampled_samples_processed_, signal_sum_ and signal_sumsq_; it's called + /// from AcceptWaveform(). + void UpdateRemainder(const VectorBase &downsampled_wave_part); + + + // The following variables don't change throughout the lifetime + // of this object. + PitchExtractionOptions opts_; + + // the first lag of the downsampled signal at which we measure NCCF + int32 nccf_first_lag_; + // the last lag of the downsampled signal at which we measure NCCF + int32 nccf_last_lag_; + + // The log-spaced lags at which we will resample the NCCF + Vector lags_; + + // This object is used to resample from evenly spaced to log-evenly-spaced + // nccf values. It's a pointer for convenience of initialization, so we don't + // have to use the initializer from the constructor. + ArbitraryResample *nccf_resampler_; + + // The following objects may change during the lifetime of this object. + + // This object is used to resample the signal. + LinearResample *signal_resampler_; + + // frame_info_ is indexed by [frame-index + 1]. frame_info_[0] is an object + // that corresponds to frame -1, which is not a real frame. 
+ std::vector frame_info_; + + + // nccf_info_ is indexed by frame-index, from frame 0 to at most + // opts_.recompute_frame - 1. It contains some information we'll + // need to recompute the tracebacks after getting a better estimate + // of the average energy of the signal. + std::vector nccf_info_; + + // Current number of frames which we can't output because Viterbi has not + // converged for them, or opts_.max_frames_latency if we have reached that + // limit. + int32 frames_latency_; + + // The forward-cost at the current frame (the last frame in frame_info_); + // this has the same dimension as lags_. We normalize each time so + // the lowest cost is zero, for numerical accuracy and so we can use float. + Vector forward_cost_; + + // stores the constant part of forward_cost_. + double forward_cost_remainder_; + + // The resampled-lag index and the NCCF (as computed for POV, without ballast + // term) for each frame, as determined by Viterbi traceback from the best + // final state. + std::vector > lag_nccf_; + + bool input_finished_; + + /// sum-squared of previously processed parts of signal; used to get NCCF + /// ballast term. Denominator is downsampled_samples_processed_. + double signal_sumsq_; + + /// sum of previously processed parts of signal; used to do mean-subtraction + /// when getting sum-squared, along with signal_sumsq_. + double signal_sum_; + + /// downsampled_samples_processed is the number of samples (after + /// downsampling) that we got in previous calls to AcceptWaveform(). + int64 downsampled_samples_processed_; + /// This is a small remainder of the previous downsampled signal; + /// it's used by ExtractFrame for frames near the boundary of two + /// waveforms supplied to AcceptWaveform(). 
+ Vector downsampled_signal_remainder_; +}; + + +OnlinePitchFeatureImpl::OnlinePitchFeatureImpl( + const PitchExtractionOptions &opts): + opts_(opts), forward_cost_remainder_(0.0), input_finished_(false), + signal_sumsq_(0.0), signal_sum_(0.0), downsampled_samples_processed_(0) { + signal_resampler_ = new LinearResample(opts.samp_freq, opts.resample_freq, + opts.lowpass_cutoff, + opts.lowpass_filter_width); + + double outer_min_lag = 1.0 / opts.max_f0 - + (opts.upsample_filter_width/(2.0 * opts.resample_freq)); + double outer_max_lag = 1.0 / opts.min_f0 + + (opts.upsample_filter_width/(2.0 * opts.resample_freq)); + nccf_first_lag_ = ceil(opts.resample_freq * outer_min_lag); + nccf_last_lag_ = floor(opts.resample_freq * outer_max_lag); + + frames_latency_ = 0; // will be set in AcceptWaveform() + + // Choose the lags at which we resample the NCCF. + SelectLags(opts, &lags_); + + // upsample_cutoff is the filter cutoff for upsampling the NCCF, which is the + // Nyquist of the resampling frequency. The NCCF is (almost completely) + // bandlimited to around "lowpass_cutoff" (1000 by default), and when the + // spectrum of this bandlimited signal is convolved with the spectrum of an + // impulse train with frequency "resample_freq", which are separated by 4kHz, + // we get energy at -5000,-3000, -1000...1000, 3000..5000, etc. Filtering at + // half the Nyquist (2000 by default) is sufficient to get only the first + // repetition. + BaseFloat upsample_cutoff = opts.resample_freq * 0.5; + + + Vector lags_offset(lags_); + // lags_offset equals lags_ (which are the log-spaced lag values we want to + // measure the NCCF at) with nccf_first_lag_ / opts.resample_freq subtracted + // from each element, so we can treat the measured NCCF values as as starting + // from sample zero in a signal that starts at the point start / + // opts.resample_freq. This is necessary because the ArbitraryResample code + // assumes that the input signal starts from sample zero. 
+ lags_offset.Add(-nccf_first_lag_ / opts.resample_freq); + + int32 num_measured_lags = nccf_last_lag_ + 1 - nccf_first_lag_; + + nccf_resampler_ = new ArbitraryResample(num_measured_lags, opts.resample_freq, + upsample_cutoff, lags_offset, + opts.upsample_filter_width); + + // add a PitchInfo object for frame -1 (not a real frame). + frame_info_.push_back(new PitchFrameInfo(lags_.Dim())); + // zeroes forward_cost_; this is what we want for the fake frame -1. + forward_cost_.Resize(lags_.Dim()); +} + + +int32 OnlinePitchFeatureImpl::NumFramesAvailable( + int64 num_downsampled_samples, bool snip_edges) const { + int32 frame_shift = opts_.NccfWindowShift(), + frame_length = opts_.NccfWindowSize(); + // Use the "full frame length" to compute the number + // of frames only if the input is not finished. + if (!input_finished_) + frame_length += nccf_last_lag_; + if (num_downsampled_samples < frame_length) { + return 0; + } else { + if (!snip_edges) { + if (input_finished_) { + return static_cast(num_downsampled_samples * 1.0f / + frame_shift + 0.5f); + } else { + return static_cast((num_downsampled_samples - frame_length / 2) * + 1.0f / frame_shift + 0.5f); + } + } else { + return static_cast((num_downsampled_samples - frame_length) / + frame_shift + 1); + } + } +} + +void OnlinePitchFeatureImpl::UpdateRemainder( + const VectorBase &downsampled_wave_part) { + // frame_info_ has an extra element at frame-1, so subtract + // one from the length. + int64 num_frames = static_cast(frame_info_.size()) - 1, + next_frame = num_frames, + frame_shift = opts_.NccfWindowShift(), + next_frame_sample = frame_shift * next_frame; + + signal_sumsq_ += VecVec(downsampled_wave_part, downsampled_wave_part); + signal_sum_ += downsampled_wave_part.Sum(); + + // next_frame_sample is the first sample index we'll need for the + // next frame. 
+ int64 next_downsampled_samples_processed = + downsampled_samples_processed_ + downsampled_wave_part.Dim(); + + if (next_frame_sample > next_downsampled_samples_processed) { + // this could only happen in the weird situation that the full frame length + // is less than the frame shift. + int32 full_frame_length = opts_.NccfWindowSize() + nccf_last_lag_; + KALDI_ASSERT(full_frame_length < frame_shift && "Code error"); + downsampled_signal_remainder_.Resize(0); + } else { + Vector new_remainder(next_downsampled_samples_processed - + next_frame_sample); + // note: next_frame_sample is the index into the entire signal, of + // new_remainder(0). + // i is the absolute index of the signal. + for (int64 i = next_frame_sample; + i < next_downsampled_samples_processed; i++) { + if (i >= downsampled_samples_processed_) { // in current signal. + new_remainder(i - next_frame_sample) = + downsampled_wave_part(i - downsampled_samples_processed_); + } else { // in old remainder; only reach here if waveform supplied is + new_remainder(i - next_frame_sample) = // tiny. + downsampled_signal_remainder_(i - downsampled_samples_processed_ + + downsampled_signal_remainder_.Dim()); + } + } + downsampled_signal_remainder_.Swap(&new_remainder); + } + downsampled_samples_processed_ = next_downsampled_samples_processed; +} + +void OnlinePitchFeatureImpl::ExtractFrame( + const VectorBase &downsampled_wave_part, + int64 sample_index, + VectorBase *window) { + int32 full_frame_length = window->Dim(); + int32 offset = static_cast(sample_index - + downsampled_samples_processed_); + + // Treat edge cases first + if (sample_index < 0) { + // Part of the frame is before the beginning of the signal. This + // should only happen if opts_.snip_edges == false, when we are + // processing the first few frames of signal. In this case + // we pad with zeros. 
+ KALDI_ASSERT(opts_.snip_edges == false); + int32 sub_frame_length = sample_index + full_frame_length; + int32 sub_frame_index = full_frame_length - sub_frame_length; + KALDI_ASSERT(sub_frame_length > 0 && sub_frame_index > 0); + window->SetZero(); + SubVector sub_window(*window, sub_frame_index, sub_frame_length); + ExtractFrame(downsampled_wave_part, 0, &sub_window); + return; + } + + if (offset + full_frame_length > downsampled_wave_part.Dim()) { + // Requested frame is past end of the signal. This should only happen if + // input_finished_ == true, when we're flushing out the last couple of + // frames of signal. In this case we pad with zeros. + KALDI_ASSERT(input_finished_); + int32 sub_frame_length = downsampled_wave_part.Dim() - offset; + KALDI_ASSERT(sub_frame_length > 0); + window->SetZero(); + SubVector sub_window(*window, 0, sub_frame_length); + ExtractFrame(downsampled_wave_part, sample_index, &sub_window); + return; + } + + // "offset" is the offset of the start of the frame, into this + // signal. + if (offset >= 0) { + // frame is full inside the new part of the signal. + window->CopyFromVec(downsampled_wave_part.Range(offset, full_frame_length)); + } else { + // frame is partly in the remainder and partly in the new part. + int32 remainder_offset = downsampled_signal_remainder_.Dim() + offset; + KALDI_ASSERT(remainder_offset >= 0); // or we didn't keep enough remainder. + KALDI_ASSERT(offset + full_frame_length > 0); // or we should have + // processed this frame last + // time. 
+ + int32 old_length = -offset, new_length = offset + full_frame_length; + window->Range(0, old_length).CopyFromVec( + downsampled_signal_remainder_.Range(remainder_offset, old_length)); + window->Range(old_length, new_length).CopyFromVec( + downsampled_wave_part.Range(0, new_length)); + } + if (opts_.preemph_coeff != 0.0) { + BaseFloat preemph_coeff = opts_.preemph_coeff; + for (int32 i = window->Dim() - 1; i > 0; i--) + (*window)(i) -= preemph_coeff * (*window)(i-1); + (*window)(0) *= (1.0 - preemph_coeff); + } +} + +bool OnlinePitchFeatureImpl::IsLastFrame(int32 frame) const { + int32 T = NumFramesReady(); + KALDI_ASSERT(frame < T); + return (input_finished_ && frame + 1 == T); +} + +BaseFloat OnlinePitchFeatureImpl::FrameShiftInSeconds() const { + return opts_.frame_shift_ms / 1000.0f; +} + +int32 OnlinePitchFeatureImpl::NumFramesReady() const { + int32 num_frames = lag_nccf_.size(), + latency = frames_latency_; + KALDI_ASSERT(latency <= num_frames); + return num_frames - latency; +} + + +void OnlinePitchFeatureImpl::GetFrame(int32 frame, + VectorBase *feat) { + KALDI_ASSERT(frame < NumFramesReady() && feat->Dim() == 2); + (*feat)(0) = lag_nccf_[frame].second; + (*feat)(1) = 1.0 / lags_(lag_nccf_[frame].first); +} + +void OnlinePitchFeatureImpl::InputFinished() { + input_finished_ = true; + // Process an empty waveform; this has an effect because + // after setting input_finished_ to true, NumFramesAvailable() + // will return a slightly larger number. + AcceptWaveform(opts_.samp_freq, Vector()); + int32 num_frames = static_cast(frame_info_.size() - 1); + if (num_frames < opts_.recompute_frame && !opts_.nccf_ballast_online) + RecomputeBacktraces(); + frames_latency_ = 0; + KALDI_VLOG(3) << "Pitch-tracking Viterbi cost is " + << (forward_cost_remainder_ / num_frames) + << " per frame, over " << num_frames << " frames."; +} + +// see comment with declaration. This is only relevant for online +// operation (it gets called for non-online mode, but is a no-op). 
+void OnlinePitchFeatureImpl::RecomputeBacktraces() { + KALDI_ASSERT(!opts_.nccf_ballast_online); + int32 num_frames = static_cast(frame_info_.size()) - 1; + + // The assertion reflects how we believe this function will be called. + KALDI_ASSERT(num_frames <= opts_.recompute_frame); + KALDI_ASSERT(nccf_info_.size() == static_cast(num_frames)); + if (num_frames == 0) + return; + double num_samp = downsampled_samples_processed_, sum = signal_sum_, + sumsq = signal_sumsq_, mean = sum / num_samp; + BaseFloat mean_square = sumsq / num_samp - mean * mean; + + bool must_recompute = false; + BaseFloat threshold = 0.01; + for (int32 frame = 0; frame < num_frames; frame++) + if (!ApproxEqual(nccf_info_[frame]->mean_square_energy, + mean_square, threshold)) + must_recompute = true; + + if (!must_recompute) { + // Nothing to do. We'll reach here, for instance, if everything was in one + // chunk and opts_.nccf_ballast_online == false. This is the case for + // offline processing. + for (size_t i = 0; i < nccf_info_.size(); i++) + delete nccf_info_[i]; + nccf_info_.clear(); + return; + } + + int32 num_states = forward_cost_.Dim(), + basic_frame_length = opts_.NccfWindowSize(); + + BaseFloat new_nccf_ballast = pow(mean_square * basic_frame_length, 2) * + opts_.nccf_ballast; + + double forward_cost_remainder = 0.0; + Vector forward_cost(num_states), // start off at zero. 
+ next_forward_cost(forward_cost); + std::vector > index_info; + + for (int32 frame = 0; frame < num_frames; frame++) { + NccfInfo &nccf_info = *nccf_info_[frame]; + BaseFloat old_mean_square = nccf_info_[frame]->mean_square_energy, + avg_norm_prod = nccf_info_[frame]->avg_norm_prod, + old_nccf_ballast = pow(old_mean_square * basic_frame_length, 2) * + opts_.nccf_ballast, + nccf_scale = pow((old_nccf_ballast + avg_norm_prod) / + (new_nccf_ballast + avg_norm_prod), + static_cast(0.5)); + // The "nccf_scale" is an estimate of the scaling factor by which the NCCF + // would change on this frame, on average, by changing the ballast term from + // "old_nccf_ballast" to "new_nccf_ballast". It's not exact because the + // "avg_norm_prod" is just an average of the product e1 * e2 of frame + // energies of the (frame, shifted-frame), but these won't change that much + // within a frame, and even if they do, the inaccuracy of the scaled NCCF + // will still be very small if the ballast term didn't change much, or if + // it's much larger or smaller than e1*e2. By doing it as a simple scaling, + // we save the overhead of the NCCF resampling, which is a considerable part + // of the whole computation. 
+ nccf_info.nccf_pitch_resampled.Scale(nccf_scale); + + frame_info_[frame + 1]->ComputeBacktraces( + opts_, nccf_info.nccf_pitch_resampled, lags_, + forward_cost, &index_info, &next_forward_cost); + + forward_cost.Swap(&next_forward_cost); + BaseFloat remainder = forward_cost.Min(); + forward_cost_remainder += remainder; + forward_cost.Add(-remainder); + } + KALDI_VLOG(3) << "Forward-cost per frame changed from " + << (forward_cost_remainder_ / num_frames) << " to " + << (forward_cost_remainder / num_frames); + + forward_cost_remainder_ = forward_cost_remainder; + forward_cost_.Swap(&forward_cost); + + int32 best_final_state; + forward_cost_.Min(&best_final_state); + + if (lag_nccf_.size() != static_cast(num_frames)) + lag_nccf_.resize(num_frames); + + frame_info_.back()->SetBestState(best_final_state, lag_nccf_); + frames_latency_ = + frame_info_.back()->ComputeLatency(opts_.max_frames_latency); + for (size_t i = 0; i < nccf_info_.size(); i++) + delete nccf_info_[i]; + nccf_info_.clear(); +} + +OnlinePitchFeatureImpl::~OnlinePitchFeatureImpl() { + delete nccf_resampler_; + delete signal_resampler_; + for (size_t i = 0; i < frame_info_.size(); i++) + delete frame_info_[i]; + for (size_t i = 0; i < nccf_info_.size(); i++) + delete nccf_info_[i]; +} + +void OnlinePitchFeatureImpl::AcceptWaveform( + BaseFloat sampling_rate, + const VectorBase &wave) { + // flush out the last few samples of input waveform only if input_finished_ == + // true. + const bool flush = input_finished_; + + Vector downsampled_wave; + signal_resampler_->Resample(wave, flush, &downsampled_wave); + + // these variables will be used to compute the root-mean-square value of the + // signal for the ballast term. 
+ double cur_sumsq = signal_sumsq_, cur_sum = signal_sum_; + int64 cur_num_samp = downsampled_samples_processed_, + prev_frame_end_sample = 0; + if (!opts_.nccf_ballast_online) { + cur_sumsq += VecVec(downsampled_wave, downsampled_wave); + cur_sum += downsampled_wave.Sum(); + cur_num_samp += downsampled_wave.Dim(); + } + + // end_frame is the total number of frames we can now process, including + // previously processed ones. + int32 end_frame = NumFramesAvailable( + downsampled_samples_processed_ + downsampled_wave.Dim(), opts_.snip_edges); + // "start_frame" is the first frame-index we process + int32 start_frame = frame_info_.size() - 1, + num_new_frames = end_frame - start_frame; + + if (num_new_frames == 0) { + UpdateRemainder(downsampled_wave); + return; + // continuing to the rest of the code would generate + // an error when sizing matrices with zero rows, and + // anyway is a waste of time. + } + + int32 num_measured_lags = nccf_last_lag_ + 1 - nccf_first_lag_, + num_resampled_lags = lags_.Dim(), + frame_shift = opts_.NccfWindowShift(), + basic_frame_length = opts_.NccfWindowSize(), + full_frame_length = basic_frame_length + nccf_last_lag_; + + Vector window(full_frame_length), + inner_prod(num_measured_lags), + norm_prod(num_measured_lags); + Matrix nccf_pitch(num_new_frames, num_measured_lags), + nccf_pov(num_new_frames, num_measured_lags); + + Vector cur_forward_cost(num_resampled_lags); + + + // Because the resampling of the NCCF is more efficient when grouped together, + // we first compute the NCCF for all frames, then resample as a matrix, then + // do the Viterbi [that happens inside the constructor of PitchFrameInfo]. + + for (int32 frame = start_frame; frame < end_frame; frame++) { + // start_sample is index into the whole wave, not just this part. 
+ int64 start_sample; + if (opts_.snip_edges) { + // Usual case: offset starts at 0 + start_sample = static_cast(frame) * frame_shift; + } else { + // When we are not snipping the edges, the first offsets may be + // negative. In this case we will pad with zeros, it should not impact + // the pitch tracker. + start_sample = + static_cast((frame + 0.5) * frame_shift) - full_frame_length / 2; + } + ExtractFrame(downsampled_wave, start_sample, &window); + if (opts_.nccf_ballast_online) { + // use only up to end of current frame to compute root-mean-square value. + // end_sample will be the sample-index into "downsampled_wave", so + // not really comparable to start_sample. + int64 end_sample = start_sample + full_frame_length - + downsampled_samples_processed_; + KALDI_ASSERT(end_sample > 0); // or should have processed this frame last + // time. Note: end_sample is one past last + // sample. + if (end_sample > downsampled_wave.Dim()) { + KALDI_ASSERT(input_finished_); + end_sample = downsampled_wave.Dim(); + } + SubVector new_part(downsampled_wave, prev_frame_end_sample, + end_sample - prev_frame_end_sample); + cur_num_samp += new_part.Dim(); + cur_sumsq += VecVec(new_part, new_part); + cur_sum += new_part.Sum(); + prev_frame_end_sample = end_sample; + } + double mean_square = cur_sumsq / cur_num_samp - + pow(cur_sum / cur_num_samp, 2.0); + + ComputeCorrelation(window, nccf_first_lag_, nccf_last_lag_, + basic_frame_length, &inner_prod, &norm_prod); + double nccf_ballast_pov = 0.0, + nccf_ballast_pitch = pow(mean_square * basic_frame_length, 2) * + opts_.nccf_ballast, + avg_norm_prod = norm_prod.Sum() / norm_prod.Dim(); + SubVector nccf_pitch_row(nccf_pitch, frame - start_frame); + ComputeNccf(inner_prod, norm_prod, nccf_ballast_pitch, + &nccf_pitch_row); + SubVector nccf_pov_row(nccf_pov, frame - start_frame); + ComputeNccf(inner_prod, norm_prod, nccf_ballast_pov, + &nccf_pov_row); + if (frame < opts_.recompute_frame) + nccf_info_.push_back(new 
NccfInfo(avg_norm_prod, mean_square)); + } + + Matrix nccf_pitch_resampled(num_new_frames, num_resampled_lags); + nccf_resampler_->Resample(nccf_pitch, &nccf_pitch_resampled); + nccf_pitch.Resize(0, 0); // no longer needed. + Matrix nccf_pov_resampled(num_new_frames, num_resampled_lags); + nccf_resampler_->Resample(nccf_pov, &nccf_pov_resampled); + nccf_pov.Resize(0, 0); // no longer needed. + + // We've finished dealing with the waveform so we can call UpdateRemainder + // now; we need to call it before we possibly call RecomputeBacktraces() + // below, which is why we don't do it at the very end. + UpdateRemainder(downsampled_wave); + + std::vector > index_info; + + for (int32 frame = start_frame; frame < end_frame; frame++) { + int32 frame_idx = frame - start_frame; + PitchFrameInfo *prev_info = frame_info_.back(), + *cur_info = new PitchFrameInfo(prev_info); + cur_info->SetNccfPov(nccf_pov_resampled.Row(frame_idx)); + cur_info->ComputeBacktraces(opts_, nccf_pitch_resampled.Row(frame_idx), + lags_, forward_cost_, &index_info, + &cur_forward_cost); + forward_cost_.Swap(&cur_forward_cost); + // Renormalize forward_cost so smallest element is zero. + BaseFloat remainder = forward_cost_.Min(); + forward_cost_remainder_ += remainder; + forward_cost_.Add(-remainder); + frame_info_.push_back(cur_info); + if (frame < opts_.recompute_frame) + nccf_info_[frame]->nccf_pitch_resampled = + nccf_pitch_resampled.Row(frame_idx); + if (frame == opts_.recompute_frame - 1 && !opts_.nccf_ballast_online) + RecomputeBacktraces(); + } + + // Trace back the best-path. + int32 best_final_state; + forward_cost_.Min(&best_final_state); + lag_nccf_.resize(frame_info_.size() - 1); // will keep any existing data. 
+ frame_info_.back()->SetBestState(best_final_state, lag_nccf_); + frames_latency_ = + frame_info_.back()->ComputeLatency(opts_.max_frames_latency); + KALDI_VLOG(4) << "Latency is " << frames_latency_; +} + + + +// Some functions that forward from OnlinePitchFeature to +// OnlinePitchFeatureImpl. +int32 OnlinePitchFeature::NumFramesReady() const { + return impl_->NumFramesReady(); +} + +OnlinePitchFeature::OnlinePitchFeature(const PitchExtractionOptions &opts) + :impl_(new OnlinePitchFeatureImpl(opts)) { } + +bool OnlinePitchFeature::IsLastFrame(int32 frame) const { + return impl_->IsLastFrame(frame); +} + +BaseFloat OnlinePitchFeature::FrameShiftInSeconds() const { + return impl_->FrameShiftInSeconds(); +} + +void OnlinePitchFeature::GetFrame(int32 frame, VectorBase *feat) { + impl_->GetFrame(frame, feat); +} + +void OnlinePitchFeature::AcceptWaveform( + BaseFloat sampling_rate, + const VectorBase &waveform) { + impl_->AcceptWaveform(sampling_rate, waveform); +} + +void OnlinePitchFeature::InputFinished() { + impl_->InputFinished(); +} + +OnlinePitchFeature::~OnlinePitchFeature() { + delete impl_; +} + + +/** + This function is called from ComputeKaldiPitch when the user + specifies opts.simulate_first_pass_online == true. It gives + the "first-pass" version of the features, which you would get + on the first decoding pass in an online setting. These may + differ slightly from the final features due to both the + way the Viterbi traceback works (this is affected by + opts.max_frames_latency), and the online way we compute + the average signal energy. 
+*/ +void ComputeKaldiPitchFirstPass( + const PitchExtractionOptions &opts, + const VectorBase &wave, + Matrix *output) { + + int32 cur_rows = 100; + Matrix feats(cur_rows, 2); + + OnlinePitchFeature pitch_extractor(opts); + KALDI_ASSERT(opts.frames_per_chunk > 0 && + "--simulate-first-pass-online option does not make sense " + "unless you specify --frames-per-chunk"); + + int32 cur_offset = 0, cur_frame = 0, samp_per_chunk = + opts.frames_per_chunk * opts.samp_freq * opts.frame_shift_ms / 1000.0f; + + while (cur_offset < wave.Dim()) { + int32 num_samp = std::min(samp_per_chunk, wave.Dim() - cur_offset); + SubVector wave_chunk(wave, cur_offset, num_samp); + pitch_extractor.AcceptWaveform(opts.samp_freq, wave_chunk); + cur_offset += num_samp; + if (cur_offset == wave.Dim()) + pitch_extractor.InputFinished(); + // Get each frame as soon as it is ready. + for (; cur_frame < pitch_extractor.NumFramesReady(); cur_frame++) { + if (cur_frame >= cur_rows) { + cur_rows *= 2; + feats.Resize(cur_rows, 2, kCopyData); + } + SubVector row(feats, cur_frame); + pitch_extractor.GetFrame(cur_frame, &row); + } + } + if (cur_frame == 0) { + KALDI_WARN << "No features output since wave file too short"; + output->Resize(0, 0); + } else { + *output = feats.RowRange(0, cur_frame); + } +} + + + +void ComputeKaldiPitch(const PitchExtractionOptions &opts, + const VectorBase &wave, + Matrix *output) { + if (opts.simulate_first_pass_online) { + ComputeKaldiPitchFirstPass(opts, wave, output); + return; + } + OnlinePitchFeature pitch_extractor(opts); + + if (opts.frames_per_chunk == 0) { + pitch_extractor.AcceptWaveform(opts.samp_freq, wave); + } else { + // the user may set opts.frames_per_chunk for better compatibility with + // online operation. 
+ KALDI_ASSERT(opts.frames_per_chunk > 0); + int32 cur_offset = 0, samp_per_chunk = + opts.frames_per_chunk * opts.samp_freq * opts.frame_shift_ms / 1000.0f; + while (cur_offset < wave.Dim()) { + int32 num_samp = std::min(samp_per_chunk, wave.Dim() - cur_offset); + SubVector wave_chunk(wave, cur_offset, num_samp); + pitch_extractor.AcceptWaveform(opts.samp_freq, wave_chunk); + cur_offset += num_samp; + } + } + pitch_extractor.InputFinished(); + int32 num_frames = pitch_extractor.NumFramesReady(); + if (num_frames == 0) { + KALDI_WARN << "No frames output in pitch extraction"; + output->Resize(0, 0); + return; + } + output->Resize(num_frames, 2); + for (int32 frame = 0; frame < num_frames; frame++) { + SubVector row(*output, frame); + pitch_extractor.GetFrame(frame, &row); + } +} + + +/* + This comment describes our invesigation of how much latency the + online-processing algorithm introduces, i.e. how many frames you would + typically have to wait until the traceback converges, if you were to set the + --max-frames-latency to a very large value. + + This was done on a couple of files of language-id data. + + /home/dpovey/kaldi-online/src/featbin/compute-kaldi-pitch-feats --frames-per-chunk=10 --max-frames-latency=100 --verbose=4 --sample-frequency=8000 --resample-frequency=2600 "scp:head -n 2 data/train/wav.scp |" ark:/dev/null 2>&1 | grep Latency | wc + 4871 24355 443991 + /home/dpovey/kaldi-online/src/featbin/compute-kaldi-pitch-feats --frames-per-chunk=10 --max-frames-latency=100 --verbose=4 --sample-frequency=8000 --resample-frequency=2600 "scp:head -n 2 data/train/wav.scp |" ark:/dev/null 2>&1 | grep Latency | grep 100 | wc + 1534 7670 141128 + +# as above, but with 50 instead of 10 in the --max-frames-latency and grep statements. + 2070 10350 188370 +# as above, but with 10 instead of 50. 
+ 4067 20335 370097 + + This says that out of 4871 selected frames [we measured the latency every 10 + frames, since --frames-per-chunk=10], in 1534 frames (31%), the latency was + >= 100 frames, i.e. >= 1 second. Including the other numbers, we can see + that + + 31% of frames had latency >= 1 second + 42% of frames had latency >= 0.5 second + 83% of frames had latency >= 0.1 second. + + This doesn't necessarily mean that we actually have a latency of >= 1 second 31% of + the time when using these features, since by using the --max-frames-latency option + (default: 30 frames), it will limit the latency to, say, 0.3 seconds, and trace back + from the best current pitch. Most of the time this will probably cause no change in + the pitch traceback since the best current pitch is probably the "right" point to + trace back from. And anyway, in the online-decoding, we will most likely rescore + the features at the end anyway, and the traceback gets recomputed, so there will + be no inaccuracy (assuming the first-pass lattice had everything we needed). + + Probably the greater source of inaccuracy due to the online algorithm is the + online energy-normalization, which affects the NCCF-ballast term, and which, + for reasons of efficiency, we don't attempt to "correct" in a later rescoring + pass. This will make the most difference in the first few frames of the file, + before the first voicing, where it will tend to produce more pitch movement + than the offline version of the algorithm. +*/ + + +// Function to do data accumulation for on-line usage +template +inline void AppendVector(const VectorBase &src, Vector *dst) { + if (src.Dim() == 0) return; + dst->Resize(dst->Dim() + src.Dim(), kCopyData); + dst->Range(dst->Dim() - src.Dim(), src.Dim()).CopyFromVec(src); +} + +/** + Note on the implementation of OnlineProcessPitch: the + OnlineFeatureInterface allows random access to features (i.e. not necessarily + sequential order), so we need to support that. 
But we don't need to support + it very efficiently, and our implementation is most efficient if frames are + accessed in sequential order. + + Also note: we have to be a bit careful in this implementation because + the input features may change. That is: if we call + src_->GetFrame(t, &vec) from GetFrame(), we can't guarantee that a later + call to src_->GetFrame(t, &vec) from another GetFrame() will return the + same value. In fact, while designing this class we used some knowledge + of how the OnlinePitchFeature class works to minimize the amount of + re-querying we had to do. +*/ +OnlineProcessPitch::OnlineProcessPitch( + const ProcessPitchOptions &opts, + OnlineFeatureInterface *src): + opts_(opts), src_(src), + dim_ ((opts.add_pov_feature ? 1 : 0) + + (opts.add_normalized_log_pitch ? 1 : 0) + + (opts.add_delta_pitch ? 1 : 0) + + (opts.add_raw_log_pitch ? 1 : 0)) { + KALDI_ASSERT(dim_ > 0 && + " At least one of the pitch features should be chosen. " + "Check your post-process-pitch options."); + KALDI_ASSERT(src->Dim() == kRawFeatureDim && + "Input feature must be pitch feature (should have dimension 2)"); +} + + +void OnlineProcessPitch::GetFrame(int32 frame, + VectorBase *feat) { + int32 frame_delayed = frame < opts_.delay ? 
0 : frame - opts_.delay; + KALDI_ASSERT(feat->Dim() == dim_ && + frame_delayed < NumFramesReady()); + int32 index = 0; + if (opts_.add_pov_feature) + (*feat)(index++) = GetPovFeature(frame_delayed); + if (opts_.add_normalized_log_pitch) + (*feat)(index++) = GetNormalizedLogPitchFeature(frame_delayed); + if (opts_.add_delta_pitch) + (*feat)(index++) = GetDeltaPitchFeature(frame_delayed); + if (opts_.add_raw_log_pitch) + (*feat)(index++) = GetRawLogPitchFeature(frame_delayed); + KALDI_ASSERT(index == dim_); +} + +BaseFloat OnlineProcessPitch::GetPovFeature(int32 frame) const { + Vector tmp(kRawFeatureDim); + src_->GetFrame(frame, &tmp); // (NCCF, pitch) from pitch extractor + BaseFloat nccf = tmp(0); + return opts_.pov_scale * NccfToPovFeature(nccf) + + opts_.pov_offset; +} + +BaseFloat OnlineProcessPitch::GetDeltaPitchFeature(int32 frame) { + // Rather than computing the delta pitch directly in code here, + // which might seem easier, we accumulate a small window of features + // and call ComputeDeltas. This might seem like overkill; the reason + // we do it this way is to ensure that the end effects (at file + // beginning and end) are handled in a consistent way. + int32 context = opts_.delta_window; + int32 start_frame = std::max(0, frame - context), + end_frame = std::min(frame + context + 1, src_->NumFramesReady()), + frames_in_window = end_frame - start_frame; + Matrix feats(frames_in_window, 1), + delta_feats; + + for (int32 f = start_frame; f < end_frame; f++) + feats(f - start_frame, 0) = GetRawLogPitchFeature(f); + + DeltaFeaturesOptions delta_opts; + delta_opts.order = 1; + delta_opts.window = opts_.delta_window; + ComputeDeltas(delta_opts, feats, &delta_feats); + while (delta_feature_noise_.size() <= static_cast(frame)) { + delta_feature_noise_.push_back(RandGauss() * + opts_.delta_pitch_noise_stddev); + } + // note: delta_feats will have two columns, second contains deltas. 
+ return (delta_feats(frame - start_frame, 1) + delta_feature_noise_[frame]) * + opts_.delta_pitch_scale; +} + +BaseFloat OnlineProcessPitch::GetRawLogPitchFeature(int32 frame) const { + Vector tmp(kRawFeatureDim); + src_->GetFrame(frame, &tmp); + BaseFloat pitch = tmp(1); + KALDI_ASSERT(pitch > 0); + return Log(pitch); +} + +BaseFloat OnlineProcessPitch::GetNormalizedLogPitchFeature(int32 frame) { + UpdateNormalizationStats(frame); + BaseFloat log_pitch = GetRawLogPitchFeature(frame), + avg_log_pitch = normalization_stats_[frame].sum_log_pitch_pov / + normalization_stats_[frame].sum_pov, + normalized_log_pitch = log_pitch - avg_log_pitch; + return normalized_log_pitch * opts_.pitch_scale; +} + + +// inline +void OnlineProcessPitch::GetNormalizationWindow(int32 t, + int32 src_frames_ready, + int32 *window_begin, + int32 *window_end) const { + int32 left_context = opts_.normalization_left_context; + int32 right_context = opts_.normalization_right_context; + *window_begin = std::max(0, t - left_context); + *window_end = std::min(t + right_context + 1, src_frames_ready); +} + + +// Makes sure the entry in normalization_stats_ for this frame is up to date; +// called from GetNormalizedLogPitchFeature. +// the cur_num_frames and input_finished variables are needed because the +// pitch features for a given frame may change as we see more data. +void OnlineProcessPitch::UpdateNormalizationStats(int32 frame) { + KALDI_ASSERT(frame >= 0); + if (normalization_stats_.size() <= frame) + normalization_stats_.resize(frame + 1); + int32 cur_num_frames = src_->NumFramesReady(); + bool input_finished = src_->IsLastFrame(cur_num_frames - 1); + + NormalizationStats &this_stats = normalization_stats_[frame]; + if (this_stats.cur_num_frames == cur_num_frames && + this_stats.input_finished == input_finished) { + // Stats are fully up-to-date. 
+ return; + } + int32 this_window_begin, this_window_end; + GetNormalizationWindow(frame, cur_num_frames, + &this_window_begin, &this_window_end); + + if (frame > 0) { + const NormalizationStats &prev_stats = normalization_stats_[frame - 1]; + if (prev_stats.cur_num_frames == cur_num_frames && + prev_stats.input_finished == input_finished) { + // we'll derive this_stats efficiently from prev_stats. + // Checking that cur_num_frames and input_finished have not changed + // ensures that the underlying features will not have changed. + this_stats = prev_stats; + int32 prev_window_begin, prev_window_end; + GetNormalizationWindow(frame - 1, cur_num_frames, + &prev_window_begin, &prev_window_end); + if (this_window_begin != prev_window_begin) { + KALDI_ASSERT(this_window_begin == prev_window_begin + 1); + Vector tmp(kRawFeatureDim); + src_->GetFrame(prev_window_begin, &tmp); + BaseFloat accurate_pov = NccfToPov(tmp(0)), + log_pitch = Log(tmp(1)); + this_stats.sum_pov -= accurate_pov; + this_stats.sum_log_pitch_pov -= accurate_pov * log_pitch; + } + if (this_window_end != prev_window_end) { + KALDI_ASSERT(this_window_end == prev_window_end + 1); + Vector tmp(kRawFeatureDim); + src_->GetFrame(prev_window_end, &tmp); + BaseFloat accurate_pov = NccfToPov(tmp(0)), + log_pitch = Log(tmp(1)); + this_stats.sum_pov += accurate_pov; + this_stats.sum_log_pitch_pov += accurate_pov * log_pitch; + } + return; + } + } + // The way we do it here is not the most efficient way to do it; + // we'll see if it becomes a problem. The issue is we have to redo + // this computation from scratch each time we process a new chunk, which + // may be a little inefficient if the chunk-size is very small. 
+ this_stats.cur_num_frames = cur_num_frames; + this_stats.input_finished = input_finished; + this_stats.sum_pov = 0.0; + this_stats.sum_log_pitch_pov = 0.0; + Vector tmp(kRawFeatureDim); + for (int32 f = this_window_begin; f < this_window_end; f++) { + src_->GetFrame(f, &tmp); + BaseFloat accurate_pov = NccfToPov(tmp(0)), + log_pitch = Log(tmp(1)); + this_stats.sum_pov += accurate_pov; + this_stats.sum_log_pitch_pov += accurate_pov * log_pitch; + } +} + +int32 OnlineProcessPitch::NumFramesReady() const { + int32 src_frames_ready = src_->NumFramesReady(); + if (src_frames_ready == 0) { + return 0; + } else if (src_->IsLastFrame(src_frames_ready - 1)) { + return src_frames_ready + opts_.delay; + } else { + return std::max(0, src_frames_ready - + opts_.normalization_right_context + opts_.delay); + } +} + +void ProcessPitch(const ProcessPitchOptions &opts, + const MatrixBase &input, + Matrix *output) { + OnlineMatrixFeature pitch_feat(input); + + OnlineProcessPitch online_process_pitch(opts, &pitch_feat); + + output->Resize(online_process_pitch.NumFramesReady(), + online_process_pitch.Dim()); + for (int32 t = 0; t < online_process_pitch.NumFramesReady(); t++) { + SubVector row(*output, t); + online_process_pitch.GetFrame(t, &row); + } +} + + +void ComputeAndProcessKaldiPitch( + const PitchExtractionOptions &pitch_opts, + const ProcessPitchOptions &process_opts, + const VectorBase &wave, + Matrix *output) { + + OnlinePitchFeature pitch_extractor(pitch_opts); + + if (pitch_opts.simulate_first_pass_online) { + KALDI_ASSERT(pitch_opts.frames_per_chunk > 0 && + "--simulate-first-pass-online option does not make sense " + "unless you specify --frames-per-chunk"); + } + + OnlineProcessPitch post_process(process_opts, &pitch_extractor); + + int32 cur_rows = 100; + Matrix feats(cur_rows, post_process.Dim()); + + int32 cur_offset = 0, cur_frame = 0, + samp_per_chunk = pitch_opts.frames_per_chunk * + pitch_opts.samp_freq * pitch_opts.frame_shift_ms / 1000.0f; + + // We request 
the first-pass features as soon as they are available, + // regardless of whether opts.simulate_first_pass_online == true. If + // opts.simulate_first_pass_online == true this should + // not affect the features generated, but it helps us to test the code + // in a way that's closer to what online decoding would see. + + while (cur_offset < wave.Dim()) { + int32 num_samp; + if (samp_per_chunk > 0) + num_samp = std::min(samp_per_chunk, wave.Dim() - cur_offset); + else // user left opts.frames_per_chunk at zero. + num_samp = wave.Dim(); + SubVector wave_chunk(wave, cur_offset, num_samp); + pitch_extractor.AcceptWaveform(pitch_opts.samp_freq, wave_chunk); + cur_offset += num_samp; + if (cur_offset == wave.Dim()) + pitch_extractor.InputFinished(); + + // Get each frame as soon as it is ready. + for (; cur_frame < post_process.NumFramesReady(); cur_frame++) { + if (cur_frame >= cur_rows) { + cur_rows *= 2; + feats.Resize(cur_rows, post_process.Dim(), kCopyData); + } + SubVector row(feats, cur_frame); + post_process.GetFrame(cur_frame, &row); + } + } + + if (pitch_opts.simulate_first_pass_online) { + if (cur_frame == 0) { + KALDI_WARN << "No features output since wave file too short"; + output->Resize(0, 0); + } else { + *output = feats.RowRange(0, cur_frame); + } + } else { + // want the "final" features for second pass, so get them again. 
+ output->Resize(post_process.NumFramesReady(), post_process.Dim()); + for (int32 frame = 0; frame < post_process.NumFramesReady(); frame++) { + SubVector row(*output, frame); + post_process.GetFrame(frame, &row); + } + } +} + + +} // namespace kaldi diff --git a/speechx/speechx/kaldi/feat/pitch-functions.h b/speechx/speechx/kaldi/feat/pitch-functions.h new file mode 100644 index 00000000..70e85380 --- /dev/null +++ b/speechx/speechx/kaldi/feat/pitch-functions.h @@ -0,0 +1,450 @@ +// feat/pitch-functions.h + +// Copyright 2013 Pegah Ghahremani +// 2014 IMSL, PKU-HKUST (author: Wei Shi) +// 2014 Yanqing Sun, Junjie Wang, +// Daniel Povey, Korbinian Riedhammer +// Xin Lei + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_FEAT_PITCH_FUNCTIONS_H_ +#define KALDI_FEAT_PITCH_FUNCTIONS_H_ + +#include +#include +#include +#include + +#include "base/kaldi-error.h" +#include "feat/mel-computations.h" +#include "itf/online-feature-itf.h" +#include "matrix/matrix-lib.h" +#include "util/common-utils.h" + +namespace kaldi { +/// @addtogroup feat FeatureExtraction +/// @{ + +struct PitchExtractionOptions { + // FrameExtractionOptions frame_opts; + BaseFloat samp_freq; // sample frequency in hertz + BaseFloat frame_shift_ms; // in milliseconds. + BaseFloat frame_length_ms; // in milliseconds. 
+ BaseFloat preemph_coeff; // Preemphasis coefficient. [use is deprecated.] + BaseFloat min_f0; // min f0 to search (Hz) + BaseFloat max_f0; // max f0 to search (Hz) + BaseFloat soft_min_f0; // Minimum f0, applied in soft way, must not + // exceed min-f0 + BaseFloat penalty_factor; // cost factor for FO change + BaseFloat lowpass_cutoff; // cutoff frequency for Low pass filter + BaseFloat resample_freq; // Integer that determines filter width when + // upsampling NCCF + BaseFloat delta_pitch; // the pitch tolerance in pruning lags + BaseFloat nccf_ballast; // Increasing this factor reduces NCCF for + // quiet frames, helping ensure pitch + // continuity in unvoiced region + int32 lowpass_filter_width; // Integer that determines filter width of + // lowpass filter + int32 upsample_filter_width; // Integer that determines filter width when + // upsampling NCCF + + // Below are newer config variables, not present in the original paper, + // that relate to the online pitch extraction algorithm. + + // The maximum number of frames of latency that we allow the pitch-processing + // to introduce, for online operation. If you set this to a large value, + // there would be no inaccuracy from the Viterbi traceback (but it might make + // you wait to see the pitch). This is not very relevant for the online + // operation: normalization-right-context is more relevant, you + // can just leave this value at zero. + int32 max_frames_latency; + + // Only relevant for the function ComputeKaldiPitch which is called by + // compute-kaldi-pitch-feats. If nonzero, we provide the input as chunks of + // this size. This affects the energy normalization which has a small effect + // on the resulting features, especially at the beginning of a file. For best + // compatibility with online operation (e.g. if you plan to train models for + // the online-deocding setup), you might want to set this to a small value, + // like one frame. 
+ int32 frames_per_chunk; + + // Only relevant for the function ComputeKaldiPitch which is called by + // compute-kaldi-pitch-feats, and only relevant if frames_per_chunk is + // nonzero. If true, it will query the features as soon as they are + // available, which simulates the first-pass features you would get in online + // decoding. If false, the features you will get will be the same as those + // available at the end of the utterance, after InputFinished() has been + // called: e.g. during lattice rescoring. + bool simulate_first_pass_online; + + // Only relevant for online operation or when emulating online operation + // (e.g. when setting frames_per_chunk). This is the frame-index on which we + // recompute the NCCF (e.g. frame-index 500 = after 5 seconds); if the + // segment ends before this we do it when the segment ends. We do this by + // re-computing the signal average energy, which affects the NCCF via the + // "ballast term", scaling the resampled NCCF by a factor derived from the + // average change in the "ballast term", and re-doing the backtrace + // computation. Making this infinity would be the most exact, but would + // introduce unwanted latency at the end of long utterances, for little + // benefit. + int32 recompute_frame; + + // This is a "hidden config" used only for testing the online pitch + // extraction. If true, we compute the signal root-mean-squared for the + // ballast term, only up to the current frame, rather than the end of the + // current chunk of signal. This makes the output insensitive to the + // chunking, which is useful for testing purposes. 
+ bool nccf_ballast_online; + bool snip_edges; + PitchExtractionOptions(): + samp_freq(16000), + frame_shift_ms(10.0), + frame_length_ms(25.0), + preemph_coeff(0.0), + min_f0(50), + max_f0(400), + soft_min_f0(10.0), + penalty_factor(0.1), + lowpass_cutoff(1000), + resample_freq(4000), + delta_pitch(0.005), + nccf_ballast(7000), + lowpass_filter_width(1), + upsample_filter_width(5), + max_frames_latency(0), + frames_per_chunk(0), + simulate_first_pass_online(false), + recompute_frame(500), + nccf_ballast_online(false), + snip_edges(true) { } + + void Register(OptionsItf *opts) { + opts->Register("sample-frequency", &samp_freq, + "Waveform data sample frequency (must match the waveform " + "file, if specified there)"); + opts->Register("frame-length", &frame_length_ms, "Frame length in " + "milliseconds"); + opts->Register("frame-shift", &frame_shift_ms, "Frame shift in " + "milliseconds"); + opts->Register("preemphasis-coefficient", &preemph_coeff, + "Coefficient for use in signal preemphasis (deprecated)"); + opts->Register("min-f0", &min_f0, + "min. F0 to search for (Hz)"); + opts->Register("max-f0", &max_f0, + "max. F0 to search for (Hz)"); + opts->Register("soft-min-f0", &soft_min_f0, + "Minimum f0, applied in soft way, must not exceed min-f0"); + opts->Register("penalty-factor", &penalty_factor, + "cost factor for FO change."); + opts->Register("lowpass-cutoff", &lowpass_cutoff, + "cutoff frequency for LowPass filter (Hz) "); + opts->Register("resample-frequency", &resample_freq, + "Frequency that we down-sample the signal to. 
Must be " + "more than twice lowpass-cutoff"); + opts->Register("delta-pitch", &delta_pitch, + "Smallest relative change in pitch that our algorithm " + "measures"); + opts->Register("nccf-ballast", &nccf_ballast, + "Increasing this factor reduces NCCF for quiet frames"); + opts->Register("nccf-ballast-online", &nccf_ballast_online, + "This is useful mainly for debug; it affects how the NCCF " + "ballast is computed."); + opts->Register("lowpass-filter-width", &lowpass_filter_width, + "Integer that determines filter width of " + "lowpass filter, more gives sharper filter"); + opts->Register("upsample-filter-width", &upsample_filter_width, + "Integer that determines filter width when upsampling NCCF"); + opts->Register("frames-per-chunk", &frames_per_chunk, "Only relevant for " + "offline pitch extraction (e.g. compute-kaldi-pitch-feats), " + "you can set it to a small nonzero value, such as 10, for " + "better feature compatibility with online decoding (affects " + "energy normalization in the algorithm)"); + opts->Register("simulate-first-pass-online", &simulate_first_pass_online, + "If true, compute-kaldi-pitch-feats will output features " + "that correspond to what an online decoder would see in the " + "first pass of decoding-- not the final version of the " + "features, which is the default. Relevant if " + "--frames-per-chunk > 0"); + opts->Register("recompute-frame", &recompute_frame, "Only relevant for " + "online pitch extraction, or for compatibility with online " + "pitch extraction. A non-critical parameter; the frame at " + "which we recompute some of the forward pointers, after " + "revising our estimate of the signal energy. 
Relevant if" + "--frames-per-chunk > 0"); + opts->Register("max-frames-latency", &max_frames_latency, "Maximum number " + "of frames of latency that we allow pitch tracking to " + "introduce into the feature processing (affects output only " + "if --frames-per-chunk > 0 and " + "--simulate-first-pass-online=true"); + opts->Register("snip-edges", &snip_edges, "If this is set to false, the " + "incomplete frames near the ending edge won't be snipped, " + "so that the number of frames is the file size divided by " + "the frame-shift. This makes different types of features " + "give the same number of frames."); + } + /// Returns the window-size in samples, after resampling. This is the + /// "basic window size", not the full window size after extending by max-lag. + // Because of floating point representation, it is more reliable to divide + // by 1000 instead of multiplying by 0.001, but it is a bit slower. + int32 NccfWindowSize() const { + return static_cast(resample_freq * frame_length_ms / 1000.0); + } + /// Returns the window-shift in samples, after resampling. + int32 NccfWindowShift() const { + return static_cast(resample_freq * frame_shift_ms / 1000.0); + } +}; + +struct ProcessPitchOptions { + BaseFloat pitch_scale; // the final normalized-log-pitch feature is scaled + // with this value + BaseFloat pov_scale; // the final POV feature is scaled with this value + BaseFloat pov_offset; // An offset that can be added to the final POV + // feature (useful for online-decoding, where we don't + // do CMN to the pitch-derived features. 
+ + BaseFloat delta_pitch_scale; + BaseFloat delta_pitch_noise_stddev; // stddev of noise we add to delta-pitch + int32 normalization_left_context; // left-context used for sliding-window + // normalization + int32 normalization_right_context; // this should be reduced in online + // decoding to reduce latency + + int32 delta_window; + int32 delay; + + bool add_pov_feature; + bool add_normalized_log_pitch; + bool add_delta_pitch; + bool add_raw_log_pitch; + + ProcessPitchOptions() : + pitch_scale(2.0), + pov_scale(2.0), + pov_offset(0.0), + delta_pitch_scale(10.0), + delta_pitch_noise_stddev(0.005), + normalization_left_context(75), + normalization_right_context(75), + delta_window(2), + delay(0), + add_pov_feature(true), + add_normalized_log_pitch(true), + add_delta_pitch(true), + add_raw_log_pitch(false) { } + + + void Register(ParseOptions *opts) { + opts->Register("pitch-scale", &pitch_scale, + "Scaling factor for the final normalized log-pitch value"); + opts->Register("pov-scale", &pov_scale, + "Scaling factor for final POV (probability of voicing) " + "feature"); + opts->Register("pov-offset", &pov_offset, + "This can be used to add an offset to the POV feature. " + "Intended for use in online decoding as a substitute for " + " CMN."); + opts->Register("delta-pitch-scale", &delta_pitch_scale, + "Term to scale the final delta log-pitch feature"); + opts->Register("delta-pitch-noise-stddev", &delta_pitch_noise_stddev, + "Standard deviation for noise we add to the delta log-pitch " + "(before scaling); should be about the same as delta-pitch " + "option to pitch creation. 
The purpose is to get rid of " + "peaks in the delta-pitch caused by discretization of pitch " + "values."); + opts->Register("normalization-left-context", &normalization_left_context, + "Left-context (in frames) for moving window normalization"); + opts->Register("normalization-right-context", &normalization_right_context, + "Right-context (in frames) for moving window normalization"); + opts->Register("delta-window", &delta_window, + "Number of frames on each side of central frame, to use for " + "delta window."); + opts->Register("delay", &delay, + "Number of frames by which the pitch information is " + "delayed."); + opts->Register("add-pov-feature", &add_pov_feature, + "If true, the warped NCCF is added to output features"); + opts->Register("add-normalized-log-pitch", &add_normalized_log_pitch, + "If true, the log-pitch with POV-weighted mean subtraction " + "over 1.5 second window is added to output features"); + opts->Register("add-delta-pitch", &add_delta_pitch, + "If true, time derivative of log-pitch is added to output " + "features"); + opts->Register("add-raw-log-pitch", &add_raw_log_pitch, + "If true, log(pitch) is added to output features"); + } +}; + + +// We don't want to expose the pitch-extraction internals here as it's +// quite complex, so we use a private implementation. +class OnlinePitchFeatureImpl; + + +// Note: to start on a new waveform, just construct a new version +// of this object. +class OnlinePitchFeature: public OnlineBaseFeature { + public: + explicit OnlinePitchFeature(const PitchExtractionOptions &opts); + + virtual int32 Dim() const { return 2; /* (NCCF, pitch) */ } + + virtual int32 NumFramesReady() const; + + virtual BaseFloat FrameShiftInSeconds() const; + + virtual bool IsLastFrame(int32 frame) const; + + /// Outputs the two-dimensional feature consisting of (pitch, NCCF). You + /// should probably post-process this using class OnlineProcessPitch. 
+ virtual void GetFrame(int32 frame, VectorBase *feat); + + virtual void AcceptWaveform(BaseFloat sampling_rate, + const VectorBase &waveform); + + virtual void InputFinished(); + + virtual ~OnlinePitchFeature(); + + private: + OnlinePitchFeatureImpl *impl_; +}; + + +/// This online-feature class implements post processing of pitch features. +/// Inputs are original 2 dims (nccf, pitch). It can produce various +/// kinds of outputs, using the default options it will be (pov-feature, +/// normalized-log-pitch, delta-log-pitch). +class OnlineProcessPitch: public OnlineFeatureInterface { + public: + virtual int32 Dim() const { return dim_; } + + virtual bool IsLastFrame(int32 frame) const { + if (frame <= -1) + return src_->IsLastFrame(-1); + else if (frame < opts_.delay) + return src_->IsLastFrame(-1) == true ? false : src_->IsLastFrame(0); + else + return src_->IsLastFrame(frame - opts_.delay); + } + virtual BaseFloat FrameShiftInSeconds() const { + return src_->FrameShiftInSeconds(); + } + + virtual int32 NumFramesReady() const; + + virtual void GetFrame(int32 frame, VectorBase *feat); + + virtual ~OnlineProcessPitch() { } + + // Does not take ownership of "src". + OnlineProcessPitch(const ProcessPitchOptions &opts, + OnlineFeatureInterface *src); + + private: + enum { kRawFeatureDim = 2}; // anonymous enum to define a constant. + // kRawFeatureDim defines the dimension + // of the input: (nccf, pitch) + + ProcessPitchOptions opts_; + OnlineFeatureInterface *src_; + int32 dim_; // Output feature dimension, set in initializer. + + struct NormalizationStats { + int32 cur_num_frames; // value of src_->NumFramesReady() when + // "mean_pitch" was set. + bool input_finished; // true if input data was finished when + // "mean_pitch" was computed. 
+ double sum_pov; // sum of pov over relevant range + double sum_log_pitch_pov; // sum of log(pitch) * pov over relevant range + + NormalizationStats(): cur_num_frames(-1), input_finished(false), + sum_pov(0.0), sum_log_pitch_pov(0.0) { } + }; + + std::vector delta_feature_noise_; + + std::vector normalization_stats_; + + /// Computes and returns the POV feature for this frame. + /// Called from GetFrame(). + inline BaseFloat GetPovFeature(int32 frame) const; + + /// Computes and returns the delta-log-pitch feature for this frame. + /// Called from GetFrame(). + inline BaseFloat GetDeltaPitchFeature(int32 frame); + + /// Computes and returns the raw log-pitch feature for this frame. + /// Called from GetFrame(). + inline BaseFloat GetRawLogPitchFeature(int32 frame) const; + + /// Computes and returns the mean-subtracted log-pitch feature for this frame. + /// Called from GetFrame(). + inline BaseFloat GetNormalizedLogPitchFeature(int32 frame); + + /// Computes the normalization window sizes. + inline void GetNormalizationWindow(int32 frame, + int32 src_frames_ready, + int32 *window_begin, + int32 *window_end) const; + + /// Makes sure the entry in normalization_stats_ for this frame is up to date; + /// called from GetNormalizedLogPitchFeature. + inline void UpdateNormalizationStats(int32 frame); +}; + + +/// This function extracts (pitch, NCCF) per frame, using the pitch extraction +/// method described in "A Pitch Extraction Algorithm Tuned for Automatic Speech +/// Recognition", Pegah Ghahremani, Bagher BabaAli, Daniel Povey, Korbinian +/// Riedhammer, Jan Trmal and Sanjeev Khudanpur, ICASSP 2014. The output will +/// have as many rows as there are frames, and two columns corresponding to +/// (NCCF, pitch) +void ComputeKaldiPitch(const PitchExtractionOptions &opts, + const VectorBase &wave, + Matrix *output); + +/// This function processes the raw (NCCF, pitch) quantities computed by +/// ComputeKaldiPitch, and processes them into features. 
By default it will +/// output three-dimensional features, (POV-feature, mean-subtracted-log-pitch, +/// delta-of-raw-pitch), but this is configurable in the options. The number of +/// rows of "output" will be the number of frames (rows) in "input", and the +/// number of columns will be the number of different types of features +/// requested (by default, 3; 4 is the max). The four config variables +/// --add-pov-feature, --add-normalized-log-pitch, --add-delta-pitch, +/// --add-raw-log-pitch determine which features we create; by default we create +/// the first three. +void ProcessPitch(const ProcessPitchOptions &opts, + const MatrixBase &input, + Matrix *output); + +/// This function combines ComputeKaldiPitch and ProcessPitch. The reason +/// why we need a separate function to do this is in order to be able to +/// accurately simulate the online pitch-processing, for testing and for +/// training models matched to the "first-pass" features. It is sensitive to +/// the variables in pitch_opts that relate to online processing, +/// i.e. max_frames_latency, frames_per_chunk, simulate_first_pass_online, +/// recompute_frame. 
+void ComputeAndProcessKaldiPitch(const PitchExtractionOptions &pitch_opts, + const ProcessPitchOptions &process_opts, + const VectorBase &wave, + Matrix *output); + + +/// @} End of "addtogroup feat" +} // namespace kaldi +#endif // KALDI_FEAT_PITCH_FUNCTIONS_H_ diff --git a/speechx/speechx/kaldi/feat/resample.cc b/speechx/speechx/kaldi/feat/resample.cc new file mode 100644 index 00000000..11f4c62b --- /dev/null +++ b/speechx/speechx/kaldi/feat/resample.cc @@ -0,0 +1,377 @@ +// feat/resample.cc + +// Copyright 2013 Pegah Ghahremani +// 2014 IMSL, PKU-HKUST (author: Wei Shi) +// 2014 Yanqing Sun, Junjie Wang +// 2014 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ + +#include +#include +#include "feat/feature-functions.h" +#include "matrix/matrix-functions.h" +#include "feat/resample.h" + +namespace kaldi { + + +LinearResample::LinearResample(int32 samp_rate_in_hz, + int32 samp_rate_out_hz, + BaseFloat filter_cutoff_hz, + int32 num_zeros): + samp_rate_in_(samp_rate_in_hz), + samp_rate_out_(samp_rate_out_hz), + filter_cutoff_(filter_cutoff_hz), + num_zeros_(num_zeros) { + KALDI_ASSERT(samp_rate_in_hz > 0.0 && + samp_rate_out_hz > 0.0 && + filter_cutoff_hz > 0.0 && + filter_cutoff_hz*2 <= samp_rate_in_hz && + filter_cutoff_hz*2 <= samp_rate_out_hz && + num_zeros > 0); + + // base_freq is the frequency of the repeating unit, which is the gcd + // of the input frequencies. + int32 base_freq = Gcd(samp_rate_in_, samp_rate_out_); + input_samples_in_unit_ = samp_rate_in_ / base_freq; + output_samples_in_unit_ = samp_rate_out_ / base_freq; + + SetIndexesAndWeights(); + Reset(); +} + +int64 LinearResample::GetNumOutputSamples(int64 input_num_samp, + bool flush) const { + // For exact computation, we measure time in "ticks" of 1.0 / tick_freq, + // where tick_freq is the least common multiple of samp_rate_in_ and + // samp_rate_out_. + int32 tick_freq = Lcm(samp_rate_in_, samp_rate_out_); + int32 ticks_per_input_period = tick_freq / samp_rate_in_; + + // work out the number of ticks in the time interval + // [ 0, input_num_samp/samp_rate_in_ ). + int64 interval_length_in_ticks = input_num_samp * ticks_per_input_period; + if (!flush) { + BaseFloat window_width = num_zeros_ / (2.0 * filter_cutoff_); + // To count the window-width in ticks we take the floor. This + // is because since we're looking for the largest integer num-out-samp + // that fits in the interval, which is open on the right, a reduction + // in interval length of less than a tick will never make a difference. + // For example, the largest integer in the interval [ 0, 2 ) and the + // largest integer in the interval [ 0, 2 - 0.9 ) are the same (both one). 
+ // So when we're subtracting the window-width we can ignore the fractional + // part. + int32 window_width_ticks = floor(window_width * tick_freq); + // The time-period of the output that we can sample gets reduced + // by the window-width (which is actually the distance from the + // center to the edge of the windowing function) if we're not + // "flushing the output". + interval_length_in_ticks -= window_width_ticks; + } + if (interval_length_in_ticks <= 0) + return 0; + int32 ticks_per_output_period = tick_freq / samp_rate_out_; + // Get the last output-sample in the closed interval, i.e. replacing [ ) with + // [ ]. Note: integer division rounds down. See + // http://en.wikipedia.org/wiki/Interval_(mathematics) for an explanation of + // the notation. + int64 last_output_samp = interval_length_in_ticks / ticks_per_output_period; + // We need the last output-sample in the open interval, so if it takes us to + // the end of the interval exactly, subtract one. + if (last_output_samp * ticks_per_output_period == interval_length_in_ticks) + last_output_samp--; + // First output-sample index is zero, so the number of output samples + // is the last output-sample plus one. + int64 num_output_samp = last_output_samp + 1; + return num_output_samp; +} + +void LinearResample::SetIndexesAndWeights() { + first_index_.resize(output_samples_in_unit_); + weights_.resize(output_samples_in_unit_); + + double window_width = num_zeros_ / (2.0 * filter_cutoff_); + + for (int32 i = 0; i < output_samples_in_unit_; i++) { + double output_t = i / static_cast(samp_rate_out_); + double min_t = output_t - window_width, max_t = output_t + window_width; + // we do ceil on the min and floor on the max, because if we did it + // the other way around we would unnecessarily include indexes just + // outside the window, with zero coefficients. It's possible + // if the arguments to the ceil and floor expressions are integers + // (e.g. 
if filter_cutoff_ has an exact ratio with the sample rates), + // that we unnecessarily include something with a zero coefficient, + // but this is only a slight efficiency issue. + int32 min_input_index = ceil(min_t * samp_rate_in_), + max_input_index = floor(max_t * samp_rate_in_), + num_indices = max_input_index - min_input_index + 1; + first_index_[i] = min_input_index; + weights_[i].Resize(num_indices); + for (int32 j = 0; j < num_indices; j++) { + int32 input_index = min_input_index + j; + double input_t = input_index / static_cast(samp_rate_in_), + delta_t = input_t - output_t; + // sign of delta_t doesn't matter. + weights_[i](j) = FilterFunc(delta_t) / samp_rate_in_; + } + } +} + + +// inline +void LinearResample::GetIndexes(int64 samp_out, + int64 *first_samp_in, + int32 *samp_out_wrapped) const { + // A unit is the smallest nonzero amount of time that is an exact + // multiple of the input and output sample periods. The unit index + // is the answer to "which numbered unit we are in". + int64 unit_index = samp_out / output_samples_in_unit_; + // samp_out_wrapped is equal to samp_out % output_samples_in_unit_ + *samp_out_wrapped = static_cast(samp_out - + unit_index * output_samples_in_unit_); + *first_samp_in = first_index_[*samp_out_wrapped] + + unit_index * input_samples_in_unit_; +} + + +void LinearResample::Resample(const VectorBase &input, + bool flush, + Vector *output) { + int32 input_dim = input.Dim(); + int64 tot_input_samp = input_sample_offset_ + input_dim, + tot_output_samp = GetNumOutputSamples(tot_input_samp, flush); + + KALDI_ASSERT(tot_output_samp >= output_sample_offset_); + + output->Resize(tot_output_samp - output_sample_offset_); + + // samp_out is the index into the total output signal, not just the part + // of it we are producing here. 
+ for (int64 samp_out = output_sample_offset_; + samp_out < tot_output_samp; + samp_out++) { + int64 first_samp_in; + int32 samp_out_wrapped; + GetIndexes(samp_out, &first_samp_in, &samp_out_wrapped); + const Vector &weights = weights_[samp_out_wrapped]; + // first_input_index is the first index into "input" that we have a weight + // for. + int32 first_input_index = static_cast(first_samp_in - + input_sample_offset_); + BaseFloat this_output; + if (first_input_index >= 0 && + first_input_index + weights.Dim() <= input_dim) { + SubVector input_part(input, first_input_index, weights.Dim()); + this_output = VecVec(input_part, weights); + } else { // Handle edge cases. + this_output = 0.0; + for (int32 i = 0; i < weights.Dim(); i++) { + BaseFloat weight = weights(i); + int32 input_index = first_input_index + i; + if (input_index < 0 && input_remainder_.Dim() + input_index >= 0) { + this_output += weight * + input_remainder_(input_remainder_.Dim() + input_index); + } else if (input_index >= 0 && input_index < input_dim) { + this_output += weight * input(input_index); + } else if (input_index >= input_dim) { + // We're past the end of the input and are adding zero; should only + // happen if the user specified flush == true, or else we would not + // be trying to output this sample. + KALDI_ASSERT(flush); + } + } + } + int32 output_index = static_cast(samp_out - output_sample_offset_); + (*output)(output_index) = this_output; + } + + if (flush) { + Reset(); // Reset the internal state. + } else { + SetRemainder(input); + input_sample_offset_ = tot_input_samp; + output_sample_offset_ = tot_output_samp; + } +} + +void LinearResample::SetRemainder(const VectorBase &input) { + Vector old_remainder(input_remainder_); + // max_remainder_needed is the width of the filter from side to side, + // measured in input samples. 
you might think it should be half that, + // but you have to consider that you might be wanting to output samples + // that are "in the past" relative to the beginning of the latest + // input... anyway, storing more remainder than needed is not harmful. + int32 max_remainder_needed = ceil(samp_rate_in_ * num_zeros_ / + filter_cutoff_); + input_remainder_.Resize(max_remainder_needed); + for (int32 index = - input_remainder_.Dim(); index < 0; index++) { + // we interpret "index" as an offset from the end of "input" and + // from the end of input_remainder_. + int32 input_index = index + input.Dim(); + if (input_index >= 0) + input_remainder_(index + input_remainder_.Dim()) = input(input_index); + else if (input_index + old_remainder.Dim() >= 0) + input_remainder_(index + input_remainder_.Dim()) = + old_remainder(input_index + old_remainder.Dim()); + // else leave it at zero. + } +} + +void LinearResample::Reset() { + input_sample_offset_ = 0; + output_sample_offset_ = 0; + input_remainder_.Resize(0); +} + +/** Here, t is a time in seconds representing an offset from + the center of the windowed filter function, and FilterFunction(t) + returns the windowed filter function, described + in the header as h(t) = f(t)g(t), evaluated at t. 
+*/ +BaseFloat LinearResample::FilterFunc(BaseFloat t) const { + BaseFloat window, // raised-cosine (Hanning) window of width + // num_zeros_/2*filter_cutoff_ + filter; // sinc filter function + if (fabs(t) < num_zeros_ / (2.0 * filter_cutoff_)) + window = 0.5 * (1 + cos(M_2PI * filter_cutoff_ / num_zeros_ * t)); + else + window = 0.0; // outside support of window function + if (t != 0) + filter = sin(M_2PI * filter_cutoff_ * t) / (M_PI * t); + else + filter = 2 * filter_cutoff_; // limit of the function at t = 0 + return filter * window; +} + + +ArbitraryResample::ArbitraryResample( + int32 num_samples_in, BaseFloat samp_rate_in, + BaseFloat filter_cutoff, const Vector &sample_points, + int32 num_zeros): + num_samples_in_(num_samples_in), + samp_rate_in_(samp_rate_in), + filter_cutoff_(filter_cutoff), + num_zeros_(num_zeros) { + KALDI_ASSERT(num_samples_in > 0 && samp_rate_in > 0.0 && + filter_cutoff > 0.0 && + filter_cutoff * 2.0 <= samp_rate_in + && num_zeros > 0); + // set up weights_ and indices_. Please try to keep all functions short and + SetIndexes(sample_points); + SetWeights(sample_points); +} + + +void ArbitraryResample::Resample(const MatrixBase &input, + MatrixBase *output) const { + // each row of "input" corresponds to the data to resample; + // the corresponding row of "output" is the resampled data. 
+ + KALDI_ASSERT(input.NumRows() == output->NumRows() && + input.NumCols() == num_samples_in_ && + output->NumCols() == weights_.size()); + + Vector output_col(output->NumRows()); + for (int32 i = 0; i < NumSamplesOut(); i++) { + SubMatrix input_part(input, 0, input.NumRows(), + first_index_[i], + weights_[i].Dim()); + const Vector &weight_vec(weights_[i]); + output_col.AddMatVec(1.0, input_part, + kNoTrans, weight_vec, 0.0); + output->CopyColFromVec(output_col, i); + } +} + +void ArbitraryResample::Resample(const VectorBase &input, + VectorBase *output) const { + KALDI_ASSERT(input.Dim() == num_samples_in_ && + output->Dim() == weights_.size()); + + int32 output_dim = output->Dim(); + for (int32 i = 0; i < output_dim; i++) { + SubVector input_part(input, first_index_[i], weights_[i].Dim()); + (*output)(i) = VecVec(input_part, weights_[i]); + } +} + +void ArbitraryResample::SetIndexes(const Vector &sample_points) { + int32 num_samples = sample_points.Dim(); + first_index_.resize(num_samples); + weights_.resize(num_samples); + BaseFloat filter_width = num_zeros_ / (2.0 * filter_cutoff_); + for (int32 i = 0; i < num_samples; i++) { + // the t values are in seconds. + BaseFloat t = sample_points(i), + t_min = t - filter_width, t_max = t + filter_width; + int32 index_min = ceil(samp_rate_in_ * t_min), + index_max = floor(samp_rate_in_ * t_max); + // the ceil on index min and the floor on index_max are because there + // is no point using indices just outside the window (coeffs would be zero). 
+ if (index_min < 0) + index_min = 0; + if (index_max >= num_samples_in_) + index_max = num_samples_in_ - 1; + first_index_[i] = index_min; + weights_[i].Resize(index_max - index_min + 1); + } +} + +void ArbitraryResample::SetWeights(const Vector &sample_points) { + int32 num_samples_out = NumSamplesOut(); + for (int32 i = 0; i < num_samples_out; i++) { + for (int32 j = 0 ; j < weights_[i].Dim(); j++) { + BaseFloat delta_t = sample_points(i) - + (first_index_[i] + j) / samp_rate_in_; + // Include at this point the factor of 1.0 / samp_rate_in_ which + // appears in the math. + weights_[i](j) = FilterFunc(delta_t) / samp_rate_in_; + } + } +} + +/** Here, t is a time in seconds representing an offset from + the center of the windowed filter function, and FilterFunction(t) + returns the windowed filter function, described + in the header as h(t) = f(t)g(t), evaluated at t. +*/ +BaseFloat ArbitraryResample::FilterFunc(BaseFloat t) const { + BaseFloat window, // raised-cosine (Hanning) window of width + // num_zeros_/2*filter_cutoff_ + filter; // sinc filter function + if (fabs(t) < num_zeros_ / (2.0 * filter_cutoff_)) + window = 0.5 * (1 + cos(M_2PI * filter_cutoff_ / num_zeros_ * t)); + else + window = 0.0; // outside support of window function + if (t != 0.0) + filter = sin(M_2PI * filter_cutoff_ * t) / (M_PI * t); + else + filter = 2.0 * filter_cutoff_; // limit of the function at zero. 
+ return filter * window; +} + +void ResampleWaveform(BaseFloat orig_freq, const VectorBase &wave, + BaseFloat new_freq, Vector *new_wave) { + BaseFloat min_freq = std::min(orig_freq, new_freq); + BaseFloat lowpass_cutoff = 0.99 * 0.5 * min_freq; + int32 lowpass_filter_width = 6; + LinearResample resampler(orig_freq, new_freq, + lowpass_cutoff, lowpass_filter_width); + resampler.Resample(wave, true, new_wave); +} +} // namespace kaldi diff --git a/speechx/speechx/kaldi/feat/resample.h b/speechx/speechx/kaldi/feat/resample.h new file mode 100644 index 00000000..e0b4688c --- /dev/null +++ b/speechx/speechx/kaldi/feat/resample.h @@ -0,0 +1,287 @@ +// feat/resample.h + +// Copyright 2013 Pegah Ghahremani +// 2014 IMSL, PKU-HKUST (author: Wei Shi) +// 2014 Yanqing Sun, Junjie Wang +// 2014 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_FEAT_RESAMPLE_H_ +#define KALDI_FEAT_RESAMPLE_H_ + +#include +#include +#include +#include + + +#include "matrix/matrix-lib.h" +#include "util/common-utils.h" +#include "base/kaldi-error.h" + +namespace kaldi { +/// @addtogroup feat FeatureExtraction +/// @{ + +/** + \file[resample.h] + + This header contains declarations of classes for resampling signals. 
The + normal cases of resampling a signal are upsampling and downsampling + (increasing and decreasing the sample rate of a signal, respectively), + although the ArbitraryResample class allows a more generic case where + we want to get samples of a signal at uneven intervals (for instance, + log-spaced). + + The input signal is always evenly spaced, say sampled with frequency S, and + we assume the original signal was band-limited to S/2 or lower. The n'th + input sample x_n (with n = 0, 1, ...) is interpreted as the original + signal's value at time n/S. + + For resampling, it is convenient to view the input signal as a + continuous function x(t) of t, where each sample x_n becomes a delta function + with magnitude x_n/S, at time n/S. If we band limit this to the Nyquist + frequency S/2, we can show that this is the same as the original signal + that was sampled. [assuming the original signal was periodic and band + limited.] In general we want to bandlimit to lower than S/2, because + we don't have a perfect filter and also because if we want to resample + at a lower frequency than S, we need to bandlimit to below half of that. + Anyway, suppose we want to bandlimit to C, with 0 < C < S/2. The perfect + rectangular filter with cutoff C is the sinc function, + \f[ f(t) = 2C sinc(2Ct), \f] + where sinc is the normalized sinc function \f$ sinc(t) = sin(pi t) / (pi t) \f$, with + \f$ sinc(0) = 1 \f$. This is not a practical filter, though, because it has + infinite support. At the cost of less-than-perfect rolloff, we can choose + a suitable windowing function g(t), and use f(t) g(t) as the filter. For + a windowing function we choose raised-cosine (Hanning) window with support + on [-w/2C, w/2C], where w >= 2 is an integer chosen by the user. w = 1 + means we window the sinc function out to its first zero on the left and right, + w = 2 means the second zero, and so on; we normally choose w to be at least two. + We call this num_zeros, not w, in the code. 
+ + Convolving the signal x(t) with this windowed filter h(t) = f(t)g(t) and evaluating the resulting + signal s(t) at an arbitrary time t is easy: we have + \f[ s(t) = 1/S \sum_n x_n h(t - n/S) \f]. + (note: the sign of t - n/S might be wrong, but it doesn't matter as the filter + and window are symmetric). + This is true for arbitrary values of t. What the class ArbitraryResample does + is to allow you to evaluate the signal for specified values of t. +*/ + + +/** + Class ArbitraryResample allows you to resample a signal (assumed zero outside + the sample region, not periodic) at arbitrary specified time values, which + don't have to be linearly spaced. The low-pass filter cutoff + "filter_cutoff_hz" should be less than half the sample rate; + "num_zeros" should probably be at least two preferably more; higher numbers give + sharper filters but will be less efficient. +*/ +class ArbitraryResample { + public: + ArbitraryResample(int32 num_samples_in, + BaseFloat samp_rate_hz, + BaseFloat filter_cutoff_hz, + const Vector &sample_points_secs, + int32 num_zeros); + + int32 NumSamplesIn() const { return num_samples_in_; } + + int32 NumSamplesOut() const { return weights_.size(); } + + /// This function does the resampling. + /// input.NumRows() and output.NumRows() should be equal + /// and nonzero. + /// input.NumCols() should equal NumSamplesIn() + /// and output.NumCols() should equal NumSamplesOut(). + void Resample(const MatrixBase &input, + MatrixBase *output) const; + + /// This version of the Resample function processes just + /// one vector. 
+ void Resample(const VectorBase &input, + VectorBase *output) const; + private: + void SetIndexes(const Vector &sample_points); + + void SetWeights(const Vector &sample_points); + + BaseFloat FilterFunc(BaseFloat t) const; + + int32 num_samples_in_; + BaseFloat samp_rate_in_; + BaseFloat filter_cutoff_; + int32 num_zeros_; + + std::vector first_index_; // The first input-sample index that we sum + // over, for this output-sample index. + std::vector > weights_; +}; + + +/** + LinearResample is a special case of ArbitraryResample, where we want to + resample a signal at linearly spaced intervals (this means we want to + upsample or downsample the signal). It is more efficient than + ArbitraryResample because we can construct it just once. + + We require that the input and output sampling rate be specified as + integers, as this is an easy way to specify that their ratio be rational. +*/ + +class LinearResample { + public: + /// Constructor. We make the input and output sample rates integers, because + /// we are going to need to find a common divisor. This should just remind + /// you that they need to be integers. The filter cutoff needs to be less + /// than samp_rate_in_hz/2 and less than samp_rate_out_hz/2. num_zeros + /// controls the sharpness of the filter, more == sharper but less efficient. + /// We suggest around 4 to 10 for normal use. + LinearResample(int32 samp_rate_in_hz, + int32 samp_rate_out_hz, + BaseFloat filter_cutoff_hz, + int32 num_zeros); + + /// This function does the resampling. If you call it with flush == true and + /// you have never called it with flush == false, it just resamples the input + /// signal (it resizes the output to a suitable number of samples). + /// + /// You can also use this function to process a signal a piece at a time. + /// suppose you break it into piece1, piece2, ... pieceN. 
You can call + /// \code{.cc} + /// Resample(piece1, &output1, false); + /// Resample(piece2, &output2, false); + /// Resample(piece3, &output3, true); + /// \endcode + /// If you call it with flush == false, it won't output the last few samples + /// but will remember them, so that if you later give it a second piece of + /// the input signal it can process it correctly. + /// If your most recent call to the object was with flush == false, it will + /// have internal state; you can remove this by calling Reset(). + /// Empty input is acceptable. + void Resample(const VectorBase &input, + bool flush, + Vector *output); + + /// Calling the function Reset() resets the state of the object prior to + /// processing a new signal; it is only necessary if you have called + /// Resample(x, y, false) for some signal, leading to a remainder of the + /// signal being retained, but then abandon processing the signal before calling + /// Resample(x, y, true) for the last piece. Calling it unnecessarily between + /// signals will not do any harm. + void Reset(); + + /// Return the input and output sampling rates (for checks, for example) + inline int32 GetInputSamplingRate() { return samp_rate_in_; } + inline int32 GetOutputSamplingRate() { return samp_rate_out_; } + private: + /// This function outputs the number of output samples we will output + /// for a signal with "input_num_samp" input samples. If flush == true, + /// we return the largest n such that + /// (n/samp_rate_out_) is in the interval [ 0, input_num_samp/samp_rate_in_ ), + /// and note that the interval is half-open. If flush == false, + /// define window_width as num_zeros / (2.0 * filter_cutoff_); + /// we return the largest n such that (n/samp_rate_out_) is in the interval + /// [ 0, input_num_samp/samp_rate_in_ - window_width ). 
+ int64 GetNumOutputSamples(int64 input_num_samp, bool flush) const; + + + /// Given an output-sample index, this function outputs to *first_samp_in the + /// first input-sample index that we have a weight on (may be negative), + /// and to *samp_out_wrapped the index into weights_ where we can get the + /// corresponding weights on the input. + inline void GetIndexes(int64 samp_out, + int64 *first_samp_in, + int32 *samp_out_wrapped) const; + + void SetRemainder(const VectorBase &input); + + void SetIndexesAndWeights(); + + BaseFloat FilterFunc(BaseFloat) const; + + // The following variables are provided by the user. + int32 samp_rate_in_; + int32 samp_rate_out_; + BaseFloat filter_cutoff_; + int32 num_zeros_; + + int32 input_samples_in_unit_; ///< The number of input samples in the + ///< smallest repeating unit: num_samp_in_ = + ///< samp_rate_in_hz / Gcd(samp_rate_in_hz, + ///< samp_rate_out_hz) + int32 output_samples_in_unit_; ///< The number of output samples in the + ///< smallest repeating unit: num_samp_out_ = + ///< samp_rate_out_hz / Gcd(samp_rate_in_hz, + ///< samp_rate_out_hz) + + + /// The first input-sample index that we sum over, for this output-sample + /// index. May be negative; any truncation at the beginning is handled + /// separately. This is just for the first few output samples, but we can + /// extrapolate the correct input-sample index for arbitrary output samples. + std::vector first_index_; + + /// Weights on the input samples, for this output-sample index. + std::vector > weights_; + + // the following variables keep track of where we are in a particular signal, + // if it is being provided over multiple calls to Resample(). + + int64 input_sample_offset_; ///< The number of input samples we have + ///< already received for this signal + ///< (including anything in remainder_) + int64 output_sample_offset_; ///< The number of samples we have already + ///< output for this signal. 
+ Vector input_remainder_; ///< A small trailing part of the + ///< previously seen input signal. +}; + +/** + Downsample or upsample a waveform. This is a convenience wrapper for the + class 'LinearResample'. + The low-pass filter cutoff used in 'LinearResample' is 0.99 of the Nyquist, + where the Nyquist is half of the minimum of (orig_freq, new_freq). The + resampling is done with a symmetric FIR filter with N_z (number of zeros) + as 6. + + We compared the downsampling results with those from the sox resampling + toolkit. + Sox's design is inspired by Laurent De Soras' paper, + https://ccrma.stanford.edu/~jos/resample/Implementation.html + + Note: we expect that while orig_freq and new_freq are of type BaseFloat, they + are actually required to have exact integer values (like 16000 or 8000) with + a ratio between them that can be expressed as a rational number with + reasonably small integer factors. +*/ +void ResampleWaveform(BaseFloat orig_freq, const VectorBase &wave, + BaseFloat new_freq, Vector *new_wave); + + +/// This function is deprecated. It is provided for backward compatibility, to avoid +/// breaking older code. +inline void DownsampleWaveForm(BaseFloat orig_freq, const VectorBase &wave, + BaseFloat new_freq, Vector *new_wave) { + ResampleWaveform(orig_freq, wave, new_freq, new_wave); +} + + +/// @} End of "addtogroup feat" +} // namespace kaldi +#endif // KALDI_FEAT_RESAMPLE_H_ diff --git a/speechx/speechx/kaldi/feat/signal.cc b/speechx/speechx/kaldi/feat/signal.cc new file mode 100644 index 00000000..a206d399 --- /dev/null +++ b/speechx/speechx/kaldi/feat/signal.cc @@ -0,0 +1,129 @@ +// feat/signal.cc + +// Copyright 2015 Tom Ko + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "feat/signal.h" + +namespace kaldi { + +void ElementwiseProductOfFft(const Vector &a, Vector *b) { + int32 num_fft_bins = a.Dim() / 2; + for (int32 i = 0; i < num_fft_bins; i++) { + // do complex multiplication + ComplexMul(a(2*i), a(2*i + 1), &((*b)(2*i)), &((*b)(2*i + 1))); + } +} + +void ConvolveSignals(const Vector &filter, Vector *signal) { + int32 signal_length = signal->Dim(); + int32 filter_length = filter.Dim(); + int32 output_length = signal_length + filter_length - 1; + Vector signal_padded(output_length); + signal_padded.SetZero(); + for (int32 i = 0; i < signal_length; i++) { + for (int32 j = 0; j < filter_length; j++) { + signal_padded(i + j) += (*signal)(i) * filter(j); + } + } + signal->Resize(output_length); + signal->CopyFromVec(signal_padded); +} + + +void FFTbasedConvolveSignals(const Vector &filter, Vector *signal) { + int32 signal_length = signal->Dim(); + int32 filter_length = filter.Dim(); + int32 output_length = signal_length + filter_length - 1; + + int32 fft_length = RoundUpToNearestPowerOfTwo(output_length); + KALDI_VLOG(1) << "fft_length for full signal convolution is " << fft_length; + + SplitRadixRealFft srfft(fft_length); + + Vector filter_padded(fft_length); + filter_padded.Range(0, filter_length).CopyFromVec(filter); + srfft.Compute(filter_padded.Data(), true); + + Vector signal_padded(fft_length); + signal_padded.Range(0, signal_length).CopyFromVec(*signal); + 
srfft.Compute(signal_padded.Data(), true); + + ElementwiseProductOfFft(filter_padded, &signal_padded); + + srfft.Compute(signal_padded.Data(), false); + signal_padded.Scale(1.0 / fft_length); + + signal->Resize(output_length); + signal->CopyFromVec(signal_padded.Range(0, output_length)); +} + +void FFTbasedBlockConvolveSignals(const Vector &filter, Vector *signal) { + int32 signal_length = signal->Dim(); + int32 filter_length = filter.Dim(); + int32 output_length = signal_length + filter_length - 1; + signal->Resize(output_length, kCopyData); + + KALDI_VLOG(1) << "Length of the filter is " << filter_length; + + int32 fft_length = RoundUpToNearestPowerOfTwo(4 * filter_length); + KALDI_VLOG(1) << "Best FFT length is " << fft_length; + + int32 block_length = fft_length - filter_length + 1; + KALDI_VLOG(1) << "Block size is " << block_length; + SplitRadixRealFft srfft(fft_length); + + Vector filter_padded(fft_length); + filter_padded.Range(0, filter_length).CopyFromVec(filter); + srfft.Compute(filter_padded.Data(), true); + + Vector temp_pad(filter_length - 1); + temp_pad.SetZero(); + Vector signal_block_padded(fft_length); + + for (int32 po = 0; po < output_length; po += block_length) { + // get a block of the signal + int32 process_length = std::min(block_length, output_length - po); + signal_block_padded.SetZero(); + signal_block_padded.Range(0, process_length).CopyFromVec(signal->Range(po, process_length)); + + srfft.Compute(signal_block_padded.Data(), true); + + ElementwiseProductOfFft(filter_padded, &signal_block_padded); + + srfft.Compute(signal_block_padded.Data(), false); + signal_block_padded.Scale(1.0 / fft_length); + + // combine the block + if (po + block_length < output_length) { // current block is not the last block + signal->Range(po, block_length).CopyFromVec(signal_block_padded.Range(0, block_length)); + signal->Range(po, filter_length - 1).AddVec(1.0, temp_pad); + temp_pad.CopyFromVec(signal_block_padded.Range(block_length, filter_length - 1)); + } 
else { + signal->Range(po, output_length - po).CopyFromVec( + signal_block_padded.Range(0, output_length - po)); + if (filter_length - 1 < output_length - po) + signal->Range(po, filter_length - 1).AddVec(1.0, temp_pad); + else + signal->Range(po, output_length - po).AddVec(1.0, temp_pad.Range(0, output_length - po)); + } + } +} +} + diff --git a/speechx/speechx/kaldi/feat/signal.h b/speechx/speechx/kaldi/feat/signal.h new file mode 100644 index 00000000..c6c3eb50 --- /dev/null +++ b/speechx/speechx/kaldi/feat/signal.h @@ -0,0 +1,58 @@ +// feat/signal.h + +// Copyright 2015 Tom Ko + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABILITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_FEAT_SIGNAL_H_ +#define KALDI_FEAT_SIGNAL_H_ +ons and +// limitations under the License. + +#ifndef KALDI_FEAT_SIGNAL_H_ +#define KALDI_FEAT_SIGNAL_H_ +
#include "base/kaldi-common.h" +#include "util/common-utils.h" + +namespace kaldi { + +/* + The following three functions provide the same functionality but use + different implementations, which differ in efficiency. After the convolution, + the length of the signal will be extended to (original signal length + + filter length - 1). +*/ + +/* + This function implements a simple non-FFT-based convolution of two signals. + It is suggested to use the FFT-based convolution function, which is more + efficient. +*/ +void ConvolveSignals(const Vector &filter, Vector *signal); + +/* + This function implements FFT-based convolution of two signals. 
+ However this should be an inefficient version of BlockConvolveSignals() + as it processes the entire signal with a single FFT. +*/ +void FFTbasedConvolveSignals(const Vector &filter, Vector *signal); + +/* + This function implements FFT-based block convolution of two signals using + overlap-add method. This is an efficient way to evaluate the discrete + convolution of a long signal with a finite impulse response filter. +*/ +void FFTbasedBlockConvolveSignals(const Vector &filter, Vector *signal); + +} // namespace kaldi + +#endif // KALDI_FEAT_SIGNAL_H_ diff --git a/speechx/speechx/kaldi/feat/wave-reader.cc b/speechx/speechx/kaldi/feat/wave-reader.cc new file mode 100644 index 00000000..f8259a3a --- /dev/null +++ b/speechx/speechx/kaldi/feat/wave-reader.cc @@ -0,0 +1,387 @@ +// feat/wave-reader.cc + +// Copyright 2009-2011 Karel Vesely; Petr Motlicek +// 2013 Florent Masson +// 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include + +#include "feat/wave-reader.h" +#include "base/kaldi-error.h" +#include "base/kaldi-utils.h" + +namespace kaldi { + +// A utility class for reading wave header. 
+struct WaveHeaderReadGofer { + std::istream &is; + bool swap; + char tag[5]; + + WaveHeaderReadGofer(std::istream &is) : is(is), swap(false) { + memset(tag, '\0', sizeof tag); + } + + void Expect4ByteTag(const char *expected) { + is.read(tag, 4); + if (is.fail()) + KALDI_ERR << "WaveData: expected " << expected + << ", failed to read anything"; + if (strcmp(tag, expected)) + KALDI_ERR << "WaveData: expected " << expected << ", got " << tag; + } + + void Read4ByteTag() { + is.read(tag, 4); + if (is.fail()) + KALDI_ERR << "WaveData: expected 4-byte chunk-name, got read error"; + } + + uint32 ReadUint32() { + union { + char result[4]; + uint32 ans; + } u; + is.read(u.result, 4); + if (swap) + KALDI_SWAP4(u.result); + if (is.fail()) + KALDI_ERR << "WaveData: unexpected end of file or read error"; + return u.ans; + } + + uint16 ReadUint16() { + union { + char result[2]; + int16 ans; + } u; + is.read(u.result, 2); + if (swap) + KALDI_SWAP2(u.result); + if (is.fail()) + KALDI_ERR << "WaveData: unexpected end of file or read error"; + return u.ans; + } +}; + +static void WriteUint32(std::ostream &os, int32 i) { + union { + char buf[4]; + int i; + } u; + u.i = i; +#ifdef __BIG_ENDIAN__ + KALDI_SWAP4(u.buf); +#endif + os.write(u.buf, 4); + if (os.fail()) + KALDI_ERR << "WaveData: error writing to stream."; +} + +static void WriteUint16(std::ostream &os, int16 i) { + union { + char buf[2]; + int16 i; + } u; + u.i = i; +#ifdef __BIG_ENDIAN__ + KALDI_SWAP2(u.buf); +#endif + os.write(u.buf, 2); + if (os.fail()) + KALDI_ERR << "WaveData: error writing to stream."; +} + +void WaveInfo::Read(std::istream &is) { + WaveHeaderReadGofer reader(is); + reader.Read4ByteTag(); + if (strcmp(reader.tag, "RIFF") == 0) + reverse_bytes_ = false; + else if (strcmp(reader.tag, "RIFX") == 0) + reverse_bytes_ = true; + else + KALDI_ERR << "WaveData: expected RIFF or RIFX, got " << reader.tag; + +#ifdef __BIG_ENDIAN__ + reverse_bytes_ = !reverse_bytes_; +#endif + reader.swap = reverse_bytes_; + + 
uint32 riff_chunk_size = reader.ReadUint32(); + reader.Expect4ByteTag("WAVE"); + + uint32 riff_chunk_read = 0; + riff_chunk_read += 4; // WAVE included in riff_chunk_size. + + // Possibly skip any RIFF tags between 'WAVE' and 'fmt '. + // Apple devices produce a filler tag 'JUNK' for memory alignment. + reader.Read4ByteTag(); + riff_chunk_read += 4; + while (strcmp(reader.tag,"fmt ") != 0) { + uint32 filler_size = reader.ReadUint32(); + riff_chunk_read += 4; + for (uint32 i = 0; i < filler_size; i++) { + is.get(); // read 1 byte, + } + riff_chunk_read += filler_size; + // get next RIFF tag, + reader.Read4ByteTag(); + riff_chunk_read += 4; + } + + KALDI_ASSERT(strcmp(reader.tag,"fmt ") == 0); + uint32 subchunk1_size = reader.ReadUint32(); + uint16 audio_format = reader.ReadUint16(); + num_channels_ = reader.ReadUint16(); + uint32 sample_rate = reader.ReadUint32(), + byte_rate = reader.ReadUint32(), + block_align = reader.ReadUint16(), + bits_per_sample = reader.ReadUint16(); + samp_freq_ = static_cast(sample_rate); + + uint32 fmt_chunk_read = 16; + if (audio_format == 1) { + if (subchunk1_size < 16) { + KALDI_ERR << "WaveData: expect PCM format data to have fmt chunk " + << "of at least size 16."; + } + } else if (audio_format == 0xFFFE) { // WAVE_FORMAT_EXTENSIBLE + uint16 extra_size = reader.ReadUint16(); + if (subchunk1_size < 40 || extra_size < 22) { + KALDI_ERR << "WaveData: malformed WAVE_FORMAT_EXTENSIBLE format data."; + } + reader.ReadUint16(); // Unused for PCM. + reader.ReadUint32(); // Channel map: we do not care. + uint32 guid1 = reader.ReadUint32(), + guid2 = reader.ReadUint32(), + guid3 = reader.ReadUint32(), + guid4 = reader.ReadUint32(); + fmt_chunk_read = 40; + + // Support only KSDATAFORMAT_SUBTYPE_PCM for now. 
Interesting formats: + // ("00000001-0000-0010-8000-00aa00389b71", KSDATAFORMAT_SUBTYPE_PCM) + // ("00000003-0000-0010-8000-00aa00389b71", KSDATAFORMAT_SUBTYPE_IEEE_FLOAT) + // ("00000006-0000-0010-8000-00aa00389b71", KSDATAFORMAT_SUBTYPE_ALAW) + // ("00000007-0000-0010-8000-00aa00389b71", KSDATAFORMAT_SUBTYPE_MULAW) + if (guid1 != 0x00000001 || guid2 != 0x00100000 || + guid3 != 0xAA000080 || guid4 != 0x719B3800) { + KALDI_ERR << "WaveData: unsupported WAVE_FORMAT_EXTENSIBLE format."; + } + } else { + KALDI_ERR << "WaveData: can read only PCM data, format id in file is: " + << audio_format; + } + + for (uint32 i = fmt_chunk_read; i < subchunk1_size; ++i) + is.get(); // use up extra data. + + if (num_channels_ == 0) + KALDI_ERR << "WaveData: no channels present"; + if (bits_per_sample != 16) + KALDI_ERR << "WaveData: unsupported bits_per_sample = " << bits_per_sample; + if (byte_rate != sample_rate * bits_per_sample/8 * num_channels_) + KALDI_ERR << "Unexpected byte rate " << byte_rate << " vs. " + << sample_rate << " * " << (bits_per_sample/8) + << " * " << num_channels_; + if (block_align != num_channels_ * bits_per_sample/8) + KALDI_ERR << "Unexpected block_align: " << block_align << " vs. " + << num_channels_ << " * " << (bits_per_sample/8); + + riff_chunk_read += 4 + subchunk1_size; + // size of what we just read, 4 for subchunk1_size + subchunk1_size itself. + + // We support an optional "fact" chunk (which is useless but which + // we encountered), and then a single "data" chunk. + + reader.Read4ByteTag(); + riff_chunk_read += 4; + + // Skip any subchunks between "fmt" and "data". Usually there will + // be a single "fact" subchunk, but on Windows there can also be a + // "list" subchunk. + while (strcmp(reader.tag, "data") != 0) { + // We will just ignore the data in these chunks. 
+ uint32 chunk_sz = reader.ReadUint32(); + if (chunk_sz != 4 && strcmp(reader.tag, "fact") == 0) + KALDI_WARN << "Expected fact chunk to be 4 bytes long."; + for (uint32 i = 0; i < chunk_sz; i++) + is.get(); + riff_chunk_read += 4 + chunk_sz; // for chunk_sz (4) + chunk contents (chunk-sz) + + // Now read the next chunk name. + reader.Read4ByteTag(); + riff_chunk_read += 4; + } + + KALDI_ASSERT(strcmp(reader.tag, "data") == 0); + uint32 data_chunk_size = reader.ReadUint32(); + riff_chunk_read += 4; + + // Figure out if the file is going to be read to the end. Values as + // observed in the wild: + bool is_stream_mode = + riff_chunk_size == 0 + || riff_chunk_size == 0xFFFFFFFF + || data_chunk_size == 0 + || data_chunk_size == 0xFFFFFFFF + || data_chunk_size == 0x7FFFF000; // This value is used by SoX. + + if (is_stream_mode) + KALDI_VLOG(1) << "Read in RIFF chunk size: " << riff_chunk_size + << ", data chunk size: " << data_chunk_size + << ". Assume 'stream mode' (reading data to EOF)."; + + if (!is_stream_mode + && std::abs(static_cast(riff_chunk_read) + + static_cast(data_chunk_size) - + static_cast(riff_chunk_size)) > 1) { + // We allow the size to be off by one without warning, because there is a + // weirdness in the format of RIFF files that means that the input may + // sometimes be padded with 1 unused byte to make the total size even. + KALDI_WARN << "Expected " << riff_chunk_size << " bytes in RIFF chunk, but " + << "after first data block there will be " << riff_chunk_read + << " + " << data_chunk_size << " bytes " + << "(we do not support reading multiple data chunks)."; + } + + if (is_stream_mode) + samp_count_ = -1; + else + samp_count_ = data_chunk_size / block_align; +} + +void WaveData::Read(std::istream &is) { + const uint32 kBlockSize = 1024 * 1024; + + WaveInfo header; + header.Read(is); + + data_.Resize(0, 0); // clear the data. + samp_freq_ = header.SampFreq(); + + std::vector buffer; + uint32 bytes_to_go = header.IsStreamed() ? 
kBlockSize : header.DataBytes(); + + // Once in a while header.DataBytes() will report an insane value; + // read the file to the end + while (is && bytes_to_go > 0) { + uint32 block_bytes = std::min(bytes_to_go, kBlockSize); + uint32 offset = buffer.size(); + buffer.resize(offset + block_bytes); + is.read(&buffer[offset], block_bytes); + uint32 bytes_read = is.gcount(); + buffer.resize(offset + bytes_read); + if (!header.IsStreamed()) + bytes_to_go -= bytes_read; + } + + if (is.bad()) + KALDI_ERR << "WaveData: file read error"; + + if (buffer.size() == 0) + KALDI_ERR << "WaveData: empty file (no data)"; + + if (!header.IsStreamed() && buffer.size() < header.DataBytes()) { + KALDI_WARN << "Expected " << header.DataBytes() << " bytes of wave data, " + << "but read only " << buffer.size() << " bytes. " + << "Truncated file?"; + } + + uint16 *data_ptr = reinterpret_cast(&buffer[0]); + + // The matrix is arranged row per channel, column per sample. + data_.Resize(header.NumChannels(), + buffer.size() / header.BlockAlign()); + for (uint32 i = 0; i < data_.NumCols(); ++i) { + for (uint32 j = 0; j < data_.NumRows(); ++j) { + int16 k = *data_ptr++; + if (header.ReverseBytes()) + KALDI_SWAP2(k); + data_(j, i) = k; + } + } +} + + +// Write 16-bit PCM. + +// note: the WAVE chunk contains 2 subchunks. +// +// subchunk2size = data.NumRows() * data.NumCols() * 2. 
+ + +void WaveData::Write(std::ostream &os) const { + os << "RIFF"; + if (data_.NumRows() == 0) + KALDI_ERR << "Error: attempting to write empty WAVE file"; + + int32 num_chan = data_.NumRows(), + num_samp = data_.NumCols(), + bytes_per_samp = 2; + + int32 subchunk2size = (num_chan * num_samp * bytes_per_samp); + int32 chunk_size = 36 + subchunk2size; + WriteUint32(os, chunk_size); + os << "WAVE"; + os << "fmt "; + WriteUint32(os, 16); + WriteUint16(os, 1); + WriteUint16(os, num_chan); + KALDI_ASSERT(samp_freq_ > 0); + WriteUint32(os, static_cast(samp_freq_)); + WriteUint32(os, static_cast(samp_freq_) * num_chan * bytes_per_samp); + WriteUint16(os, num_chan * bytes_per_samp); + WriteUint16(os, 8 * bytes_per_samp); + os << "data"; + WriteUint32(os, subchunk2size); + + const BaseFloat *data_ptr = data_.Data(); + int32 stride = data_.Stride(); + + int num_clipped = 0; + for (int32 i = 0; i < num_samp; i++) { + for (int32 j = 0; j < num_chan; j++) { + int32 elem = static_cast(trunc(data_ptr[j * stride + i])); + int16 elem_16 = static_cast(elem); + if (elem < std::numeric_limits::min()) { + elem_16 = std::numeric_limits::min(); + ++num_clipped; + } else if (elem > std::numeric_limits::max()) { + elem_16 = std::numeric_limits::max(); + ++num_clipped; + } +#ifdef __BIG_ENDIAN__ + KALDI_SWAP2(elem_16); +#endif + os.write(reinterpret_cast(&elem_16), 2); + } + } + if (os.fail()) + KALDI_ERR << "Error writing wave data to stream."; + if (num_clipped > 0) + KALDI_WARN << "WARNING: clipped " << num_clipped + << " samples out of total " << num_chan * num_samp + << ". 
Reduce volume?"; +} + + +} // end namespace kaldi diff --git a/speechx/speechx/kaldi/feat/wave-reader.h b/speechx/speechx/kaldi/feat/wave-reader.h new file mode 100644 index 00000000..dae74139 --- /dev/null +++ b/speechx/speechx/kaldi/feat/wave-reader.h @@ -0,0 +1,248 @@ +// feat/wave-reader.h + +// Copyright 2009-2011 Karel Vesely; Microsoft Corporation +// 2013 Florent Masson +// 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +/* +// THE WAVE FORMAT IS SPECIFIED IN: +// https:// ccrma.stanford.edu/courses/422/projects/WaveFormat/ +// +// +// +// RIFF +// | +// WAVE +// | \ \ \ +// fmt_ data ... data +// +// +// Riff is a general container, which usually contains one WAVE chunk +// each WAVE chunk has header sub-chunk 'fmt_' +// and one or more data sub-chunks 'data' +// +// [Note from Dan: to say that the wave format was ever "specified" anywhere is +// not quite right. The guy who invented the wave format attempted to create +// a formal specification but it did not completely make sense. And there +// doesn't seem to be a consensus on what makes a valid wave file, +// particularly where the accuracy of header information is concerned.] 
+*/ + + +#ifndef KALDI_FEAT_WAVE_READER_H_ +#define KALDI_FEAT_WAVE_READER_H_ + +#include + +#include "base/kaldi-types.h" +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" + + +namespace kaldi { + +/// For historical reasons, we scale waveforms to the range +/// (2^15-1)*[-1, 1], not the usual default DSP range [-1, 1]. +const BaseFloat kWaveSampleMax = 32768.0; + +/// This class reads and hold wave file header information. +class WaveInfo { + public: + WaveInfo() : samp_freq_(0), samp_count_(0), + num_channels_(0), reverse_bytes_(0) {} + + /// Is stream size unknown? Duration and SampleCount not valid if true. + bool IsStreamed() const { return samp_count_ < 0; } + + /// Sample frequency, Hz. + BaseFloat SampFreq() const { return samp_freq_; } + + /// Number of samples in stream. Invalid if IsStreamed() is true. + uint32 SampleCount() const { return samp_count_; } + + /// Approximate duration, seconds. Invalid if IsStreamed() is true. + BaseFloat Duration() const { return samp_count_ / samp_freq_; } + + /// Number of channels, 1 to 16. + int32 NumChannels() const { return num_channels_; } + + /// Bytes per sample. + size_t BlockAlign() const { return 2 * num_channels_; } + + /// Wave data bytes. Invalid if IsStreamed() is true. + size_t DataBytes() const { return samp_count_ * BlockAlign(); } + + /// Is data file byte order different from machine byte order? + bool ReverseBytes() const { return reverse_bytes_; } + + /// 'is' should be opened in binary mode. Read() will throw on error. + /// On success 'is' will be positioned at the beginning of wave data. + void Read(std::istream &is); + + private: + BaseFloat samp_freq_; + int32 samp_count_; // 0 if empty, -1 if undefined length. + uint8 num_channels_; + bool reverse_bytes_; // File endianness differs from host. +}; + +/// This class's purpose is to read in Wave files. 
+class WaveData { + public: + WaveData(BaseFloat samp_freq, const MatrixBase &data) + : data_(data), samp_freq_(samp_freq) {} + + WaveData() : samp_freq_(0.0) {} + + /// Read() will throw on error. It's valid to call Read() more than once-- + /// in this case it will destroy what was there before. + /// "is" should be opened in binary mode. + void Read(std::istream &is); + + /// Write() will throw on error. os should be opened in binary mode. + void Write(std::ostream &os) const; + + // This function returns the wave data-- it's in a matrix + // because there may be multiple channels. In the normal case + // there's just one channel so Data() will have one row. + const Matrix &Data() const { return data_; } + + BaseFloat SampFreq() const { return samp_freq_; } + + // Returns the duration in seconds + BaseFloat Duration() const { return data_.NumCols() / samp_freq_; } + + void CopyFrom(const WaveData &other) { + samp_freq_ = other.samp_freq_; + data_.CopyFromMat(other.data_); + } + + void Clear() { + data_.Resize(0, 0); + samp_freq_ = 0.0; + } + + void Swap(WaveData *other) { + data_.Swap(&(other->data_)); + std::swap(samp_freq_, other->samp_freq_); + } + + private: + static const uint32 kBlockSize = 1024 * 1024; // Use 1M bytes. + Matrix data_; + BaseFloat samp_freq_; +}; + + +// Holder class for .wav files that enables us to read (but not write) .wav +// files. c.f. util/kaldi-holder.h we don't use the KaldiObjectHolder template +// because we don't want to check for the \0B binary header. We could have faked +// it by pretending to read in the wave data in text mode after failing to find +// the \0B header, but that would have been a little ugly. +class WaveHolder { + public: + typedef WaveData T; + + static bool Write(std::ostream &os, bool binary, const T &t) { + // We don't write the binary-mode header here [always binary]. + if (!binary) + KALDI_ERR << "Wave data can only be written in binary mode."; + try { + t.Write(os); // throws exception on failure. 
+ return true; + } catch (const std::exception &e) { + KALDI_WARN << "Exception caught in WaveHolder object (writing). " + << e.what(); + return false; // write failure. + } + } + void Copy(const T &t) { t_.CopyFrom(t); } + + static bool IsReadInBinary() { return true; } + + void Clear() { t_.Clear(); } + + T &Value() { return t_; } + + WaveHolder &operator = (const WaveHolder &other) { + t_.CopyFrom(other.t_); + return *this; + } + WaveHolder(const WaveHolder &other): t_(other.t_) {} + + WaveHolder() {} + + bool Read(std::istream &is) { + // We don't look for the binary-mode header here [always binary] + try { + t_.Read(is); // Throws exception on failure. + return true; + } catch (const std::exception &e) { + KALDI_WARN << "Exception caught in WaveHolder::Read(). " << e.what(); + return false; + } + } + + void Swap(WaveHolder *other) { + t_.Swap(&(other->t_)); + } + + bool ExtractRange(const WaveHolder &other, const std::string &range) { + KALDI_ERR << "ExtractRange is not defined for this type of holder."; + return false; + } + + private: + T t_; +}; + +// This is like WaveHolder but when you just want the metadata- +// it leaves the actual data undefined, it doesn't read it. +class WaveInfoHolder { + public: + typedef WaveInfo T; + + void Clear() { info_ = WaveInfo(); } + void Swap(WaveInfoHolder *other) { std::swap(info_, other->info_); } + T &Value() { return info_; } + static bool IsReadInBinary() { return true; } + + bool Read(std::istream &is) { + try { + info_.Read(is); // Throws exception on failure. + return true; + } catch (const std::exception &e) { + KALDI_WARN << "Exception caught in WaveInfoHolder::Read(). 
" << e.what(); + return false; + } + } + + bool ExtractRange(const WaveInfoHolder &other, const std::string &range) { + KALDI_ERR << "ExtractRange is not defined for this type of holder."; + return false; + } + + private: + WaveInfo info_; +}; + + +} // namespace kaldi + +#endif // KALDI_FEAT_WAVE_READER_H_ diff --git a/speechx/speechx/kaldi/matrix/BUILD b/speechx/speechx/kaldi/matrix/BUILD new file mode 100644 index 00000000..cefac6fc --- /dev/null +++ b/speechx/speechx/kaldi/matrix/BUILD @@ -0,0 +1,39 @@ +# Copyright (c) 2020 PeachLab. All Rights Reserved. +# Author : goat.zhou@qq.com (Yang Zhou) + +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = 'kaldi-matrix', + srcs = [ + 'compressed-matrix.cc', + 'kaldi-matrix.cc', + 'kaldi-vector.cc', + 'matrix-functions.cc', + 'optimization.cc', + 'packed-matrix.cc', + 'qr.cc', + 'sparse-matrix.cc', + 'sp-matrix.cc', + 'srfft.cc', + 'tp-matrix.cc', + ], + hdrs = glob(["*.h"]), + deps = [ + '//base:kaldi-base', + '//common/third_party/openblas:openblas', + ], + linkopts=['-lgfortran'], +) + +cc_binary( + name = 'matrix-lib-test', + srcs = [ + 'matrix-lib-test.cc', + ], + deps = [ + ':kaldi-matrix', + '//util:kaldi-util', + ], +) + diff --git a/speechx/speechx/kaldi/matrix/CMakeLists.txt b/speechx/speechx/kaldi/matrix/CMakeLists.txt new file mode 100644 index 00000000..a4dbde2e --- /dev/null +++ b/speechx/speechx/kaldi/matrix/CMakeLists.txt @@ -0,0 +1,16 @@ + +add_library(kaldi-matrix +compressed-matrix.cc +kaldi-matrix.cc +kaldi-vector.cc +matrix-functions.cc +optimization.cc +packed-matrix.cc +qr.cc +sparse-matrix.cc +sp-matrix.cc +srfft.cc +tp-matrix.cc +) + +target_link_libraries(kaldi-matrix gfortran kaldi-base libopenblas.a) diff --git a/speechx/speechx/kaldi/matrix/cblas-wrappers.h b/speechx/speechx/kaldi/matrix/cblas-wrappers.h new file mode 100644 index 00000000..f869ab7e --- /dev/null +++ b/speechx/speechx/kaldi/matrix/cblas-wrappers.h @@ -0,0 +1,491 @@ +// matrix/cblas-wrappers.h + +// 
Copyright 2012 Johns Hopkins University (author: Daniel Povey); +// Haihua Xu; Wei Shi + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_MATRIX_CBLAS_WRAPPERS_H_ +#define KALDI_MATRIX_CBLAS_WRAPPERS_H_ 1 + + +#include +#include "matrix/sp-matrix.h" +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" +#include "matrix/matrix-functions.h" +#include "matrix/kaldi-blas.h" + +// Do not include this file directly. It is to be included +// by .cc files in this directory. 
+ +namespace kaldi { + + +inline void cblas_Xcopy(const int N, const float *X, const int incX, float *Y, + const int incY) { + cblas_scopy(N, X, incX, Y, incY); +} + +inline void cblas_Xcopy(const int N, const double *X, const int incX, double *Y, + const int incY) { + cblas_dcopy(N, X, incX, Y, incY); +} + + +inline float cblas_Xasum(const int N, const float *X, const int incX) { + return cblas_sasum(N, X, incX); +} + +inline double cblas_Xasum(const int N, const double *X, const int incX) { + return cblas_dasum(N, X, incX); +} + +inline void cblas_Xrot(const int N, float *X, const int incX, float *Y, + const int incY, const float c, const float s) { + cblas_srot(N, X, incX, Y, incY, c, s); +} +inline void cblas_Xrot(const int N, double *X, const int incX, double *Y, + const int incY, const double c, const double s) { + cblas_drot(N, X, incX, Y, incY, c, s); +} +inline float cblas_Xdot(const int N, const float *const X, + const int incX, const float *const Y, + const int incY) { + return cblas_sdot(N, X, incX, Y, incY); +} +inline double cblas_Xdot(const int N, const double *const X, + const int incX, const double *const Y, + const int incY) { + return cblas_ddot(N, X, incX, Y, incY); +} +inline void cblas_Xaxpy(const int N, const float alpha, const float *X, + const int incX, float *Y, const int incY) { + cblas_saxpy(N, alpha, X, incX, Y, incY); +} +inline void cblas_Xaxpy(const int N, const double alpha, const double *X, + const int incX, double *Y, const int incY) { + cblas_daxpy(N, alpha, X, incX, Y, incY); +} +inline void cblas_Xscal(const int N, const float alpha, float *data, + const int inc) { + cblas_sscal(N, alpha, data, inc); +} +inline void cblas_Xscal(const int N, const double alpha, double *data, + const int inc) { + cblas_dscal(N, alpha, data, inc); +} +inline void cblas_Xspmv(const float alpha, const int num_rows, const float *Mdata, + const float *v, const int v_inc, + const float beta, float *y, const int y_inc) { + cblas_sspmv(CblasRowMajor, 
CblasLower, num_rows, alpha, Mdata, v, v_inc, beta, y, y_inc); +} +inline void cblas_Xspmv(const double alpha, const int num_rows, const double *Mdata, + const double *v, const int v_inc, + const double beta, double *y, const int y_inc) { + cblas_dspmv(CblasRowMajor, CblasLower, num_rows, alpha, Mdata, v, v_inc, beta, y, y_inc); +} +inline void cblas_Xtpmv(MatrixTransposeType trans, const float *Mdata, + const int num_rows, float *y, const int y_inc) { + cblas_stpmv(CblasRowMajor, CblasLower, static_cast(trans), + CblasNonUnit, num_rows, Mdata, y, y_inc); +} +inline void cblas_Xtpmv(MatrixTransposeType trans, const double *Mdata, + const int num_rows, double *y, const int y_inc) { + cblas_dtpmv(CblasRowMajor, CblasLower, static_cast(trans), + CblasNonUnit, num_rows, Mdata, y, y_inc); +} + + +inline void cblas_Xtpsv(MatrixTransposeType trans, const float *Mdata, + const int num_rows, float *y, const int y_inc) { + cblas_stpsv(CblasRowMajor, CblasLower, static_cast(trans), + CblasNonUnit, num_rows, Mdata, y, y_inc); +} +inline void cblas_Xtpsv(MatrixTransposeType trans, const double *Mdata, + const int num_rows, double *y, const int y_inc) { + cblas_dtpsv(CblasRowMajor, CblasLower, static_cast(trans), + CblasNonUnit, num_rows, Mdata, y, y_inc); +} + +// x = alpha * M * y + beta * x +inline void cblas_Xspmv(MatrixIndexT dim, float alpha, const float *Mdata, + const float *ydata, MatrixIndexT ystride, + float beta, float *xdata, MatrixIndexT xstride) { + cblas_sspmv(CblasRowMajor, CblasLower, dim, alpha, Mdata, + ydata, ystride, beta, xdata, xstride); +} +inline void cblas_Xspmv(MatrixIndexT dim, double alpha, const double *Mdata, + const double *ydata, MatrixIndexT ystride, + double beta, double *xdata, MatrixIndexT xstride) { + cblas_dspmv(CblasRowMajor, CblasLower, dim, alpha, Mdata, + ydata, ystride, beta, xdata, xstride); +} + +// Implements A += alpha * (x y' + y x'); A is symmetric matrix. 
+inline void cblas_Xspr2(MatrixIndexT dim, float alpha, const float *Xdata, + MatrixIndexT incX, const float *Ydata, MatrixIndexT incY, + float *Adata) { + cblas_sspr2(CblasRowMajor, CblasLower, dim, alpha, Xdata, + incX, Ydata, incY, Adata); +} +inline void cblas_Xspr2(MatrixIndexT dim, double alpha, const double *Xdata, + MatrixIndexT incX, const double *Ydata, MatrixIndexT incY, + double *Adata) { + cblas_dspr2(CblasRowMajor, CblasLower, dim, alpha, Xdata, + incX, Ydata, incY, Adata); +} + +// Implements A += alpha * (x x'); A is symmetric matrix. +inline void cblas_Xspr(MatrixIndexT dim, float alpha, const float *Xdata, + MatrixIndexT incX, float *Adata) { + cblas_sspr(CblasRowMajor, CblasLower, dim, alpha, Xdata, incX, Adata); +} +inline void cblas_Xspr(MatrixIndexT dim, double alpha, const double *Xdata, + MatrixIndexT incX, double *Adata) { + cblas_dspr(CblasRowMajor, CblasLower, dim, alpha, Xdata, incX, Adata); +} + +// sgemv,dgemv: y = alpha M x + beta y. +inline void cblas_Xgemv(MatrixTransposeType trans, MatrixIndexT num_rows, + MatrixIndexT num_cols, float alpha, const float *Mdata, + MatrixIndexT stride, const float *xdata, + MatrixIndexT incX, float beta, float *ydata, MatrixIndexT incY) { + cblas_sgemv(CblasRowMajor, static_cast(trans), num_rows, + num_cols, alpha, Mdata, stride, xdata, incX, beta, ydata, incY); +} +inline void cblas_Xgemv(MatrixTransposeType trans, MatrixIndexT num_rows, + MatrixIndexT num_cols, double alpha, const double *Mdata, + MatrixIndexT stride, const double *xdata, + MatrixIndexT incX, double beta, double *ydata, MatrixIndexT incY) { + cblas_dgemv(CblasRowMajor, static_cast(trans), num_rows, + num_cols, alpha, Mdata, stride, xdata, incX, beta, ydata, incY); +} + +// sgbmv, dgmmv: y = alpha M x + + beta * y. 
+inline void cblas_Xgbmv(MatrixTransposeType trans, MatrixIndexT num_rows, + MatrixIndexT num_cols, MatrixIndexT num_below, + MatrixIndexT num_above, float alpha, const float *Mdata, + MatrixIndexT stride, const float *xdata, + MatrixIndexT incX, float beta, float *ydata, MatrixIndexT incY) { + cblas_sgbmv(CblasRowMajor, static_cast(trans), num_rows, + num_cols, num_below, num_above, alpha, Mdata, stride, xdata, + incX, beta, ydata, incY); +} +inline void cblas_Xgbmv(MatrixTransposeType trans, MatrixIndexT num_rows, + MatrixIndexT num_cols, MatrixIndexT num_below, + MatrixIndexT num_above, double alpha, const double *Mdata, + MatrixIndexT stride, const double *xdata, + MatrixIndexT incX, double beta, double *ydata, MatrixIndexT incY) { + cblas_dgbmv(CblasRowMajor, static_cast(trans), num_rows, + num_cols, num_below, num_above, alpha, Mdata, stride, xdata, + incX, beta, ydata, incY); +} + + +template +inline void Xgemv_sparsevec(MatrixTransposeType trans, MatrixIndexT num_rows, + MatrixIndexT num_cols, Real alpha, const Real *Mdata, + MatrixIndexT stride, const Real *xdata, + MatrixIndexT incX, Real beta, Real *ydata, + MatrixIndexT incY) { + if (trans == kNoTrans) { + if (beta != 1.0) cblas_Xscal(num_rows, beta, ydata, incY); + for (MatrixIndexT i = 0; i < num_cols; i++) { + Real x_i = xdata[i * incX]; + if (x_i == 0.0) continue; + // Add to ydata, the i'th column of M, times alpha * x_i + cblas_Xaxpy(num_rows, x_i * alpha, Mdata + i, stride, ydata, incY); + } + } else { + if (beta != 1.0) cblas_Xscal(num_cols, beta, ydata, incY); + for (MatrixIndexT i = 0; i < num_rows; i++) { + Real x_i = xdata[i * incX]; + if (x_i == 0.0) continue; + // Add to ydata, the i'th row of M, times alpha * x_i + cblas_Xaxpy(num_cols, x_i * alpha, + Mdata + (i * stride), 1, ydata, incY); + } + } +} + +inline void cblas_Xgemm(const float alpha, + MatrixTransposeType transA, + const float *Adata, + MatrixIndexT a_num_rows, MatrixIndexT a_num_cols, MatrixIndexT a_stride, + 
MatrixTransposeType transB, + const float *Bdata, MatrixIndexT b_stride, + const float beta, + float *Mdata, + MatrixIndexT num_rows, MatrixIndexT num_cols,MatrixIndexT stride) { + cblas_sgemm(CblasRowMajor, static_cast(transA), + static_cast(transB), + num_rows, num_cols, transA == kNoTrans ? a_num_cols : a_num_rows, + alpha, Adata, a_stride, Bdata, b_stride, + beta, Mdata, stride); +} +inline void cblas_Xgemm(const double alpha, + MatrixTransposeType transA, + const double *Adata, + MatrixIndexT a_num_rows, MatrixIndexT a_num_cols, MatrixIndexT a_stride, + MatrixTransposeType transB, + const double *Bdata, MatrixIndexT b_stride, + const double beta, + double *Mdata, + MatrixIndexT num_rows, MatrixIndexT num_cols,MatrixIndexT stride) { + cblas_dgemm(CblasRowMajor, static_cast(transA), + static_cast(transB), + num_rows, num_cols, transA == kNoTrans ? a_num_cols : a_num_rows, + alpha, Adata, a_stride, Bdata, b_stride, + beta, Mdata, stride); +} + + +inline void cblas_Xsymm(const float alpha, + MatrixIndexT sz, + const float *Adata,MatrixIndexT a_stride, + const float *Bdata,MatrixIndexT b_stride, + const float beta, + float *Mdata, MatrixIndexT stride) { + cblas_ssymm(CblasRowMajor, CblasLeft, CblasLower, sz, sz, alpha, Adata, + a_stride, Bdata, b_stride, beta, Mdata, stride); +} +inline void cblas_Xsymm(const double alpha, + MatrixIndexT sz, + const double *Adata,MatrixIndexT a_stride, + const double *Bdata,MatrixIndexT b_stride, + const double beta, + double *Mdata, MatrixIndexT stride) { + cblas_dsymm(CblasRowMajor, CblasLeft, CblasLower, sz, sz, alpha, Adata, + a_stride, Bdata, b_stride, beta, Mdata, stride); +} +// ger: M += alpha x y^T. 
+inline void cblas_Xger(MatrixIndexT num_rows, MatrixIndexT num_cols, float alpha, + const float *xdata, MatrixIndexT incX, const float *ydata, + MatrixIndexT incY, float *Mdata, MatrixIndexT stride) { + cblas_sger(CblasRowMajor, num_rows, num_cols, alpha, xdata, 1, ydata, 1, + Mdata, stride); +} +inline void cblas_Xger(MatrixIndexT num_rows, MatrixIndexT num_cols, double alpha, + const double *xdata, MatrixIndexT incX, const double *ydata, + MatrixIndexT incY, double *Mdata, MatrixIndexT stride) { + cblas_dger(CblasRowMajor, num_rows, num_cols, alpha, xdata, 1, ydata, 1, + Mdata, stride); +} + +// syrk: symmetric rank-k update. +// if trans==kNoTrans, then C = alpha A A^T + beta C +// else C = alpha A^T A + beta C. +// note: dim_c is dim(C), other_dim_a is the "other" dimension of A, i.e. +// num-cols(A) if kNoTrans, or num-rows(A) if kTrans. +// We only need the row-major and lower-triangular option of this, and this +// is hard-coded. +inline void cblas_Xsyrk ( + const MatrixTransposeType trans, const MatrixIndexT dim_c, + const MatrixIndexT other_dim_a, const float alpha, const float *A, + const MatrixIndexT a_stride, const float beta, float *C, + const MatrixIndexT c_stride) { + cblas_ssyrk(CblasRowMajor, CblasLower, static_cast(trans), + dim_c, other_dim_a, alpha, A, a_stride, beta, C, c_stride); +} + +inline void cblas_Xsyrk( + const MatrixTransposeType trans, const MatrixIndexT dim_c, + const MatrixIndexT other_dim_a, const double alpha, const double *A, + const MatrixIndexT a_stride, const double beta, double *C, + const MatrixIndexT c_stride) { + cblas_dsyrk(CblasRowMajor, CblasLower, static_cast(trans), + dim_c, other_dim_a, alpha, A, a_stride, beta, C, c_stride); +} + +/// matrix-vector multiply using a banded matrix; we always call this +/// with b = 1 meaning we're multiplying by a diagonal matrix. This is used for +/// elementwise multiplication. We miss some of the arguments out of this +/// wrapper. 
+inline void cblas_Xsbmv1( + const MatrixIndexT dim, + const double *A, + const double alpha, + const double *x, + const double beta, + double *y) { + cblas_dsbmv(CblasRowMajor, CblasLower, dim, 0, alpha, A, + 1, x, 1, beta, y, 1); +} + +inline void cblas_Xsbmv1( + const MatrixIndexT dim, + const float *A, + const float alpha, + const float *x, + const float beta, + float *y) { + cblas_ssbmv(CblasRowMajor, CblasLower, dim, 0, alpha, A, + 1, x, 1, beta, y, 1); +} + +/// This is not really a wrapper for CBLAS as CBLAS does not have this; in future we could +/// extend this somehow. +inline void mul_elements( + const MatrixIndexT dim, + const double *a, + double *b) { // does b *= a, elementwise. + double c1, c2, c3, c4; + MatrixIndexT i; + for (i = 0; i + 4 <= dim; i += 4) { + c1 = a[i] * b[i]; + c2 = a[i+1] * b[i+1]; + c3 = a[i+2] * b[i+2]; + c4 = a[i+3] * b[i+3]; + b[i] = c1; + b[i+1] = c2; + b[i+2] = c3; + b[i+3] = c4; + } + for (; i < dim; i++) + b[i] *= a[i]; +} + +inline void mul_elements( + const MatrixIndexT dim, + const float *a, + float *b) { // does b *= a, elementwise. 
+ float c1, c2, c3, c4; + MatrixIndexT i; + for (i = 0; i + 4 <= dim; i += 4) { + c1 = a[i] * b[i]; + c2 = a[i+1] * b[i+1]; + c3 = a[i+2] * b[i+2]; + c4 = a[i+3] * b[i+3]; + b[i] = c1; + b[i+1] = c2; + b[i+2] = c3; + b[i+3] = c4; + } + for (; i < dim; i++) + b[i] *= a[i]; +} + + + +// add clapack here +#if !defined(HAVE_ATLAS) +inline void clapack_Xtptri(KaldiBlasInt *num_rows, float *Mdata, KaldiBlasInt *result) { + stptri_(const_cast("U"), const_cast("N"), num_rows, Mdata, result); +} +inline void clapack_Xtptri(KaldiBlasInt *num_rows, double *Mdata, KaldiBlasInt *result) { + dtptri_(const_cast("U"), const_cast("N"), num_rows, Mdata, result); +} +// +inline void clapack_Xgetrf2(KaldiBlasInt *num_rows, KaldiBlasInt *num_cols, + float *Mdata, KaldiBlasInt *stride, KaldiBlasInt *pivot, + KaldiBlasInt *result) { + sgetrf_(num_rows, num_cols, Mdata, stride, pivot, result); +} +inline void clapack_Xgetrf2(KaldiBlasInt *num_rows, KaldiBlasInt *num_cols, + double *Mdata, KaldiBlasInt *stride, KaldiBlasInt *pivot, + KaldiBlasInt *result) { + dgetrf_(num_rows, num_cols, Mdata, stride, pivot, result); +} + +// +inline void clapack_Xgetri2(KaldiBlasInt *num_rows, float *Mdata, KaldiBlasInt *stride, + KaldiBlasInt *pivot, float *p_work, + KaldiBlasInt *l_work, KaldiBlasInt *result) { + sgetri_(num_rows, Mdata, stride, pivot, p_work, l_work, result); +} +inline void clapack_Xgetri2(KaldiBlasInt *num_rows, double *Mdata, KaldiBlasInt *stride, + KaldiBlasInt *pivot, double *p_work, + KaldiBlasInt *l_work, KaldiBlasInt *result) { + dgetri_(num_rows, Mdata, stride, pivot, p_work, l_work, result); +} +// +inline void clapack_Xgesvd(char *v, char *u, KaldiBlasInt *num_cols, + KaldiBlasInt *num_rows, float *Mdata, KaldiBlasInt *stride, + float *sv, float *Vdata, KaldiBlasInt *vstride, + float *Udata, KaldiBlasInt *ustride, float *p_work, + KaldiBlasInt *l_work, KaldiBlasInt *result) { + sgesvd_(v, u, + num_cols, num_rows, Mdata, stride, + sv, Vdata, vstride, Udata, ustride, + p_work, 
l_work, result); +} +inline void clapack_Xgesvd(char *v, char *u, KaldiBlasInt *num_cols, + KaldiBlasInt *num_rows, double *Mdata, KaldiBlasInt *stride, + double *sv, double *Vdata, KaldiBlasInt *vstride, + double *Udata, KaldiBlasInt *ustride, double *p_work, + KaldiBlasInt *l_work, KaldiBlasInt *result) { + dgesvd_(v, u, + num_cols, num_rows, Mdata, stride, + sv, Vdata, vstride, Udata, ustride, + p_work, l_work, result); +} +// +void inline clapack_Xsptri(KaldiBlasInt *num_rows, float *Mdata, + KaldiBlasInt *ipiv, float *work, KaldiBlasInt *result) { + ssptri_(const_cast("U"), num_rows, Mdata, ipiv, work, result); +} +void inline clapack_Xsptri(KaldiBlasInt *num_rows, double *Mdata, + KaldiBlasInt *ipiv, double *work, KaldiBlasInt *result) { + dsptri_(const_cast("U"), num_rows, Mdata, ipiv, work, result); +} +// +void inline clapack_Xsptrf(KaldiBlasInt *num_rows, float *Mdata, + KaldiBlasInt *ipiv, KaldiBlasInt *result) { + ssptrf_(const_cast("U"), num_rows, Mdata, ipiv, result); +} +void inline clapack_Xsptrf(KaldiBlasInt *num_rows, double *Mdata, + KaldiBlasInt *ipiv, KaldiBlasInt *result) { + dsptrf_(const_cast("U"), num_rows, Mdata, ipiv, result); +} +#else +inline void clapack_Xgetrf(MatrixIndexT num_rows, MatrixIndexT num_cols, + float *Mdata, MatrixIndexT stride, + int *pivot, int *result) { + *result = clapack_sgetrf(CblasColMajor, num_rows, num_cols, + Mdata, stride, pivot); +} + +inline void clapack_Xgetrf(MatrixIndexT num_rows, MatrixIndexT num_cols, + double *Mdata, MatrixIndexT stride, + int *pivot, int *result) { + *result = clapack_dgetrf(CblasColMajor, num_rows, num_cols, + Mdata, stride, pivot); +} +// +inline int clapack_Xtrtri(int num_rows, float *Mdata, MatrixIndexT stride) { + return clapack_strtri(CblasColMajor, CblasUpper, CblasNonUnit, num_rows, + Mdata, stride); +} + +inline int clapack_Xtrtri(int num_rows, double *Mdata, MatrixIndexT stride) { + return clapack_dtrtri(CblasColMajor, CblasUpper, CblasNonUnit, num_rows, + Mdata, stride); +} 
+// +inline void clapack_Xgetri(MatrixIndexT num_rows, float *Mdata, MatrixIndexT stride, + int *pivot, int *result) { + *result = clapack_sgetri(CblasColMajor, num_rows, Mdata, stride, pivot); +} +inline void clapack_Xgetri(MatrixIndexT num_rows, double *Mdata, MatrixIndexT stride, + int *pivot, int *result) { + *result = clapack_dgetri(CblasColMajor, num_rows, Mdata, stride, pivot); +} +#endif + +} +// namespace kaldi + +#endif diff --git a/speechx/speechx/kaldi/matrix/compressed-matrix.cc b/speechx/speechx/kaldi/matrix/compressed-matrix.cc new file mode 100644 index 00000000..13214b25 --- /dev/null +++ b/speechx/speechx/kaldi/matrix/compressed-matrix.cc @@ -0,0 +1,876 @@ +// matrix/compressed-matrix.cc + +// Copyright 2012 Johns Hopkins University (author: Daniel Povey) +// Frantisek Skala, Wei Shi +// 2015 Tom Ko + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "matrix/compressed-matrix.h" +#include + +namespace kaldi { + +//static +MatrixIndexT CompressedMatrix::DataSize(const GlobalHeader &header) { + // Returns size in bytes of the data. 
+ DataFormat format = static_cast(header.format); + if (format == kOneByteWithColHeaders) { + return sizeof(GlobalHeader) + + header.num_cols * (sizeof(PerColHeader) + header.num_rows); + } else if (format == kTwoByte) { + return sizeof(GlobalHeader) + + 2 * header.num_rows * header.num_cols; + } else { + KALDI_ASSERT(format == kOneByte); + return sizeof(GlobalHeader) + + header.num_rows * header.num_cols; + } +} + +// scale all element of matrix by scaling floats +// in GlobalHeader with alpha. +void CompressedMatrix::Scale(float alpha) { + if (data_ != NULL) { + GlobalHeader *h = reinterpret_cast(data_); + // scale the floating point values in each PerColHolder + // and leave all integers the same. + h->min_value *= alpha; + h->range *= alpha; + } +} + +template // static inline +void CompressedMatrix::ComputeGlobalHeader( + const MatrixBase &mat, CompressionMethod method, + GlobalHeader *header) { + if (method == kAutomaticMethod) { + if (mat.NumRows() > 8) method = kSpeechFeature; + else method = kTwoByteAuto; + } + + switch (method) { + case kSpeechFeature: + header->format = static_cast(kOneByteWithColHeaders); // 1. + break; + case kTwoByteAuto: case kTwoByteSignedInteger: + header->format = static_cast(kTwoByte); // 2. + break; + case kOneByteAuto: case kOneByteUnsignedInteger: case kOneByteZeroOne: + header->format = static_cast(kOneByte); // 3. + break; + default: + KALDI_ERR << "Invalid compression type: " + << static_cast(method); + } + + header->num_rows = mat.NumRows(); + header->num_cols = mat.NumCols(); + + // Now compute 'min_value' and 'range'. + switch (method) { + case kSpeechFeature: case kTwoByteAuto: case kOneByteAuto: { + float min_value = mat.Min(), max_value = mat.Max(); + // ensure that max_value is strictly greater than min_value, even if matrix is + // constant; this avoids crashes in ComputeColHeader when compressing speech + // featupres. 
+ if (max_value == min_value) + max_value = min_value + (1.0 + fabs(min_value)); + KALDI_ASSERT(min_value - min_value == 0 && + max_value - max_value == 0 && + "Cannot compress a matrix with Nan's or Inf's"); + + header->min_value = min_value; + header->range = max_value - min_value; + + // we previously checked that max_value != min_value, so their + // difference should be nonzero. + KALDI_ASSERT(header->range > 0.0); + break; + } + case kTwoByteSignedInteger: { + header->min_value = -32768.0; + header->range = 65535.0; + break; + } + case kOneByteUnsignedInteger: { + header->min_value = 0.0; + header->range = 255.0; + break; + } + case kOneByteZeroOne: { + header->min_value = 0.0; + header->range = 1.0; + break; + } + default: + KALDI_ERR << "Unknown compression method = " + << static_cast(method); + } + KALDI_COMPILE_TIME_ASSERT(sizeof(*header) == 20); // otherwise + // something weird is happening and our code probably won't work or + // won't be robust across platforms. +} + +template +void CompressedMatrix::CopyFromMat( + const MatrixBase &mat, CompressionMethod method) { + if (data_ != NULL) { + delete [] static_cast(data_); // call delete [] because was allocated with new float[] + data_ = NULL; + } + if (mat.NumRows() == 0) { return; } // Zero-size matrix stored as zero pointer. 
+ + + GlobalHeader global_header; + ComputeGlobalHeader(mat, method, &global_header); + + int32 data_size = DataSize(global_header); + + data_ = AllocateData(data_size); + + *(reinterpret_cast(data_)) = global_header; + + DataFormat format = static_cast(global_header.format); + if (format == kOneByteWithColHeaders) { + PerColHeader *header_data = + reinterpret_cast(static_cast(data_) + + sizeof(GlobalHeader)); + uint8 *byte_data = + reinterpret_cast(header_data + global_header.num_cols); + + const Real *matrix_data = mat.Data(); + + for (int32 col = 0; col < global_header.num_cols; col++) { + CompressColumn(global_header, + matrix_data + col, mat.Stride(), + global_header.num_rows, + header_data, byte_data); + header_data++; + byte_data += global_header.num_rows; + } + } else if (format == kTwoByte) { + uint16 *data = reinterpret_cast(static_cast(data_) + + sizeof(GlobalHeader)); + int32 num_rows = mat.NumRows(), num_cols = mat.NumCols(); + for (int32 r = 0; r < num_rows; r++) { + const Real *row_data = mat.RowData(r); + for (int32 c = 0; c < num_cols; c++) + data[c] = FloatToUint16(global_header, row_data[c]); + data += num_cols; + } + } else { + KALDI_ASSERT(format == kOneByte); + uint8 *data = reinterpret_cast(static_cast(data_) + + sizeof(GlobalHeader)); + int32 num_rows = mat.NumRows(), num_cols = mat.NumCols(); + for (int32 r = 0; r < num_rows; r++) { + const Real *row_data = mat.RowData(r); + for (int32 c = 0; c < num_cols; c++) + data[c] = FloatToUint8(global_header, row_data[c]); + data += num_cols; + } + } +} + +// Instantiate the template for float and double. 
+template +void CompressedMatrix::CopyFromMat(const MatrixBase &mat, + CompressionMethod method); + +template +void CompressedMatrix::CopyFromMat(const MatrixBase &mat, + CompressionMethod method); + + +CompressedMatrix::CompressedMatrix( + const CompressedMatrix &cmat, + const MatrixIndexT row_offset, + const MatrixIndexT num_rows, + const MatrixIndexT col_offset, + const MatrixIndexT num_cols, + bool allow_padding): data_(NULL) { + int32 old_num_rows = cmat.NumRows(), old_num_cols = cmat.NumCols(); + + if (old_num_rows == 0) { + KALDI_ASSERT(num_rows == 0 && num_cols == 0); + // The empty matrix is stored as a zero pointer. + return; + } + + KALDI_ASSERT(row_offset < old_num_rows); + KALDI_ASSERT(col_offset < old_num_cols); + KALDI_ASSERT(row_offset >= 0 || allow_padding); + KALDI_ASSERT(col_offset >= 0); + KALDI_ASSERT(row_offset + num_rows <= old_num_rows || allow_padding); + KALDI_ASSERT(col_offset + num_cols <= old_num_cols); + + if (num_rows == 0 || num_cols == 0) { return; } + + bool padding_is_used = (row_offset < 0 || + row_offset + num_rows > old_num_rows); + + GlobalHeader new_global_header; + KALDI_COMPILE_TIME_ASSERT(sizeof(new_global_header) == 20); + + GlobalHeader *old_global_header = reinterpret_cast(cmat.Data()); + + new_global_header = *old_global_header; + new_global_header.num_cols = num_cols; + new_global_header.num_rows = num_rows; + + // We don't switch format from 1 -> 2 (in case of size reduction) yet; if this + // is needed, we will do this below by creating a temporary Matrix. 
+ new_global_header.format = old_global_header->format; + + data_ = AllocateData(DataSize(new_global_header)); // allocate memory + *(reinterpret_cast(data_)) = new_global_header; + + + DataFormat format = static_cast(old_global_header->format); + if (format == kOneByteWithColHeaders) { + PerColHeader *old_per_col_header = + reinterpret_cast(old_global_header + 1); + uint8 *old_byte_data = + reinterpret_cast(old_per_col_header + + old_global_header->num_cols); + PerColHeader *new_per_col_header = + reinterpret_cast( + reinterpret_cast(data_) + 1); + + memcpy(new_per_col_header, old_per_col_header + col_offset, + sizeof(PerColHeader) * num_cols); + + uint8 *new_byte_data = + reinterpret_cast(new_per_col_header + num_cols); + if (!padding_is_used) { + uint8 *old_start_of_subcol = + old_byte_data + row_offset + (col_offset * old_num_rows), + *new_start_of_col = new_byte_data; + for (int32 i = 0; i < num_cols; i++) { + memcpy(new_start_of_col, old_start_of_subcol, num_rows); + new_start_of_col += num_rows; + old_start_of_subcol += old_num_rows; + } + } else { + uint8 *old_start_of_col = + old_byte_data + (col_offset * old_num_rows), + *new_start_of_col = new_byte_data; + for (int32 i = 0; i < num_cols; i++) { + + for (int32 j = 0; j < num_rows; j++) { + int32 old_j = j + row_offset; + if (old_j < 0) old_j = 0; + else if (old_j >= old_num_rows) old_j = old_num_rows - 1; + new_start_of_col[j] = old_start_of_col[old_j]; + } + new_start_of_col += num_rows; + old_start_of_col += old_num_rows; + } + } + } else if (format == kTwoByte) { + const uint16 *old_data = + reinterpret_cast(old_global_header + 1); + uint16 *new_row_data = + reinterpret_cast(reinterpret_cast(data_) + 1); + + for (int32 row = 0; row < num_rows; row++) { + int32 old_row = row + row_offset; + // The next two lines are only relevant if padding_is_used. 
+ if (old_row < 0) old_row = 0; + else if (old_row >= old_num_rows) old_row = old_num_rows - 1; + const uint16 *old_row_data = + old_data + col_offset + (old_num_cols * old_row); + memcpy(new_row_data, old_row_data, sizeof(uint16) * num_cols); + new_row_data += num_cols; + } + } else { + KALDI_ASSERT(format == kOneByte); + const uint8 *old_data = + reinterpret_cast(old_global_header + 1); + uint8 *new_row_data = + reinterpret_cast(reinterpret_cast(data_) + 1); + + for (int32 row = 0; row < num_rows; row++) { + int32 old_row = row + row_offset; + // The next two lines are only relevant if padding_is_used. + if (old_row < 0) old_row = 0; + else if (old_row >= old_num_rows) old_row = old_num_rows - 1; + const uint8 *old_row_data = + old_data + col_offset + (old_num_cols * old_row); + memcpy(new_row_data, old_row_data, sizeof(uint8) * num_cols); + new_row_data += num_cols; + } + } + + if (num_rows < 8 && format == kOneByteWithColHeaders) { + // format was 1 but we want it to be 2 -> create a temporary + // Matrix (uncompress), re-compress, and swap. + // This gives us almost exact reconstruction while saving + // memory (the elements take more space but there will be + // no per-column headers). + Matrix temp(this->NumRows(), this->NumCols(), + kUndefined); + this->CopyToMat(&temp); + CompressedMatrix temp_cmat(temp, kTwoByteAuto); + this->Swap(&temp_cmat); + } +} + + +template +CompressedMatrix &CompressedMatrix::operator =(const MatrixBase &mat) { + this->CopyFromMat(mat); + return *this; +} + +// Instantiate the template for float and double. +template +CompressedMatrix& CompressedMatrix::operator =(const MatrixBase &mat); + +template +CompressedMatrix& CompressedMatrix::operator =(const MatrixBase &mat); + +inline uint16 CompressedMatrix::FloatToUint16( + const GlobalHeader &global_header, + float value) { + float f = (value - global_header.min_value) / + global_header.range; + if (f > 1.0) f = 1.0; // Note: this should not happen. 
+ if (f < 0.0) f = 0.0; // Note: this should not happen. + return static_cast(f * 65535 + 0.499); // + 0.499 is to + // round to closest int; avoids bias. +} + + +inline uint8 CompressedMatrix::FloatToUint8( + const GlobalHeader &global_header, + float value) { + float f = (value - global_header.min_value) / + global_header.range; + if (f > 1.0) f = 1.0; // Note: this should not happen. + if (f < 0.0) f = 0.0; // Note: this should not happen. + return static_cast(f * 255 + 0.499); // + 0.499 is to + // round to closest int; avoids bias. +} + + +inline float CompressedMatrix::Uint16ToFloat( + const GlobalHeader &global_header, + uint16 value) { + // the constant 1.52590218966964e-05 is 1/65535. + return global_header.min_value + + global_header.range * 1.52590218966964e-05F * value; +} + +template // static +void CompressedMatrix::ComputeColHeader( + const GlobalHeader &global_header, + const Real *data, MatrixIndexT stride, + int32 num_rows, CompressedMatrix::PerColHeader *header) { + KALDI_ASSERT(num_rows > 0); + std::vector sdata(num_rows); // the sorted data. + for (size_t i = 0, size = sdata.size(); i < size; i++) + sdata[i] = data[i*stride]; + + if (num_rows >= 5) { + int quarter_nr = num_rows/4; + // std::sort(sdata.begin(), sdata.end()); + // The elements at positions 0, quarter_nr, + // 3*quarter_nr, and num_rows-1 need to be in sorted order. + std::nth_element(sdata.begin(), sdata.begin() + quarter_nr, sdata.end()); + // Now, sdata.begin() + quarter_nr contains the element that would appear + // in sorted order, in that position. + std::nth_element(sdata.begin(), sdata.begin(), sdata.begin() + quarter_nr); + // Now, sdata.begin() and sdata.begin() + quarter_nr contain the elements + // that would appear at those positions in sorted order. 
+ std::nth_element(sdata.begin() + quarter_nr + 1, + sdata.begin() + (3*quarter_nr), sdata.end()); + // Now, sdata.begin(), sdata.begin() + quarter_nr, and sdata.begin() + + // 3*quarter_nr, contain the elements that would appear at those positions + // in sorted order. + std::nth_element(sdata.begin() + (3*quarter_nr) + 1, sdata.end() - 1, + sdata.end()); + // Now, sdata.begin(), sdata.begin() + quarter_nr, and sdata.begin() + + // 3*quarter_nr, and sdata.end() - 1, contain the elements that would appear + // at those positions in sorted order. + + header->percentile_0 = + std::min(FloatToUint16(global_header, sdata[0]), 65532); + header->percentile_25 = + std::min( + std::max( + FloatToUint16(global_header, sdata[quarter_nr]), + header->percentile_0 + static_cast(1)), 65533); + header->percentile_75 = + std::min( + std::max( + FloatToUint16(global_header, sdata[3*quarter_nr]), + header->percentile_25 + static_cast(1)), 65534); + header->percentile_100 = std::max( + FloatToUint16(global_header, sdata[num_rows-1]), + header->percentile_75 + static_cast(1)); + + } else { // handle this pathological case. + std::sort(sdata.begin(), sdata.end()); + // Note: we know num_rows is at least 1. 
+ header->percentile_0 = + std::min(FloatToUint16(global_header, sdata[0]), + 65532); + if (num_rows > 1) + header->percentile_25 = + std::min( + std::max(FloatToUint16(global_header, sdata[1]), + header->percentile_0 + 1), 65533); + else + header->percentile_25 = header->percentile_0 + 1; + if (num_rows > 2) + header->percentile_75 = + std::min( + std::max(FloatToUint16(global_header, sdata[2]), + header->percentile_25 + 1), 65534); + else + header->percentile_75 = header->percentile_25 + 1; + if (num_rows > 3) + header->percentile_100 = + std::max(FloatToUint16(global_header, sdata[3]), + header->percentile_75 + 1); + else + header->percentile_100 = header->percentile_75 + 1; + } +} + +// static +inline uint8 CompressedMatrix::FloatToChar( + float p0, float p25, float p75, float p100, + float value) { + int ans; + if (value < p25) { // range [ p0, p25 ) covered by + // characters 0 .. 64. We round to the closest int. + float f = (value - p0) / (p25 - p0); + ans = static_cast(f * 64 + 0.5); + // Note: the checks on the next two lines + // are necessary in pathological cases when all the elements in a row + // are the same and the percentile_* values are separated by one. + if (ans < 0) ans = 0; + if (ans > 64) ans = 64; + } else if (value < p75) { // range [ p25, p75 )covered + // by characters 64 .. 192. We round to the closest int. + float f = (value - p25) / (p75 - p25); + ans = 64 + static_cast(f * 128 + 0.5); + if (ans < 64) ans = 64; + if (ans > 192) ans = 192; + } else { // range [ p75, p100 ] covered by + // characters 192 .. 255. Note: this last range + // has fewer characters than the left range, because + // we go up to 255, not 256. 
+ float f = (value - p75) / (p100 - p75); + ans = 192 + static_cast(f * 63 + 0.5); + if (ans < 192) ans = 192; + if (ans > 255) ans = 255; + } + return static_cast(ans); +} + + +// static +inline float CompressedMatrix::CharToFloat( + float p0, float p25, float p75, float p100, + uint8 value) { + if (value <= 64) { + return p0 + (p25 - p0) * value * (1/64.0); + } else if (value <= 192) { + return p25 + (p75 - p25) * (value - 64) * (1/128.0); + } else { + return p75 + (p100 - p75) * (value - 192) * (1/63.0); + } +} + + +template // static +void CompressedMatrix::CompressColumn( + const GlobalHeader &global_header, + const Real *data, MatrixIndexT stride, + int32 num_rows, CompressedMatrix::PerColHeader *header, + uint8 *byte_data) { + ComputeColHeader(global_header, data, stride, + num_rows, header); + + float p0 = Uint16ToFloat(global_header, header->percentile_0), + p25 = Uint16ToFloat(global_header, header->percentile_25), + p75 = Uint16ToFloat(global_header, header->percentile_75), + p100 = Uint16ToFloat(global_header, header->percentile_100); + + for (int32 i = 0; i < num_rows; i++) { + Real this_data = data[i * stride]; + byte_data[i] = FloatToChar(p0, p25, p75, p100, this_data); + } +} + +// static +void* CompressedMatrix::AllocateData(int32 num_bytes) { + KALDI_ASSERT(num_bytes > 0); + KALDI_COMPILE_TIME_ASSERT(sizeof(float) == 4); + // round size up to nearest number of floats. 
+ return reinterpret_cast(new float[(num_bytes/3) + 4]); +} + +void CompressedMatrix::Write(std::ostream &os, bool binary) const { + if (binary) { // Binary-mode write: + if (data_ != NULL) { + GlobalHeader &h = *reinterpret_cast(data_); + DataFormat format = static_cast(h.format); + if (format == kOneByteWithColHeaders) { + WriteToken(os, binary, "CM"); + } else if (format == kTwoByte) { + WriteToken(os, binary, "CM2"); + } else if (format == kOneByte) { + WriteToken(os, binary, "CM3"); + } + MatrixIndexT size = DataSize(h); // total size of data in data_ + // We don't write out the "int32 format", hence the + 4, - 4. + os.write(reinterpret_cast(data_) + 4, size - 4); + } else { // special case: where data_ == NULL, we treat it as an empty + // matrix. + WriteToken(os, binary, "CM"); + GlobalHeader h; + h.range = h.min_value = 0.0; + h.num_rows = h.num_cols = 0; + os.write(reinterpret_cast(&h), sizeof(h)); + } + } else { + // In text mode, just use the same format as a regular matrix. + // This is not compressed. + Matrix temp_mat(this->NumRows(), this->NumCols(), + kUndefined); + this->CopyToMat(&temp_mat); + temp_mat.Write(os, binary); + } + if (os.fail()) + KALDI_ERR << "Error writing compressed matrix to stream."; +} + +void CompressedMatrix::Read(std::istream &is, bool binary) { + if (data_ != NULL) { + delete [] (static_cast(data_)); + data_ = NULL; + } + if (binary) { + int peekval = Peek(is, binary); + if (peekval == 'C') { + std::string tok; // Should be CM (format 1) or CM2 (format 2) + ReadToken(is, binary, &tok); + GlobalHeader h; + if (tok == "CM") { h.format = 1; } // kOneByteWithColHeaders + else if (tok == "CM2") { h.format = 2; } // kTwoByte + else if (tok == "CM3") { h.format = 3; } // kOneByte + else { + KALDI_ERR << "Unexpected token " << tok << ", expecting CM, CM2 or CM3"; + } + // don't read the "format" -> hence + 4, - 4. 
+ is.read(reinterpret_cast(&h) + 4, sizeof(h) - 4); + if (is.fail()) + KALDI_ERR << "Failed to read header"; + if (h.num_cols == 0) // empty matrix. + return; + int32 size = DataSize(h), remaining_size = size - sizeof(GlobalHeader); + data_ = AllocateData(size); + *(reinterpret_cast(data_)) = h; + is.read(reinterpret_cast(data_) + sizeof(GlobalHeader), + remaining_size); + } else { + // Assume that what we're reading is a regular Matrix. This might be the + // case if you changed your code, making a Matrix into a CompressedMatrix, + // and you want back-compatibility for reading. + Matrix M; + M.Read(is, binary); // This will crash if it was not a Matrix. + this->CopyFromMat(M); + } + } else { // Text-mode read. In this case you don't get to + // choose the compression type. Anyway this branch would only + // be taken when debugging. + Matrix temp; + temp.Read(is, binary); + this->CopyFromMat(temp); + } + if (is.fail()) + KALDI_ERR << "Failed to read data."; +} + +template +void CompressedMatrix::CopyToMat(MatrixBase *mat, + MatrixTransposeType trans) const { + if (trans == kTrans) { + Matrix temp(this->NumCols(), this->NumRows()); + CopyToMat(&temp, kNoTrans); + mat->CopyFromMat(temp, kTrans); + return; + } + + if (data_ == NULL) { + KALDI_ASSERT(mat->NumRows() == 0); + KALDI_ASSERT(mat->NumCols() == 0); + return; + } + GlobalHeader *h = reinterpret_cast(data_); + int32 num_cols = h->num_cols, num_rows = h->num_rows; + KALDI_ASSERT(mat->NumRows() == num_rows); + KALDI_ASSERT(mat->NumCols() == num_cols); + + DataFormat format = static_cast(h->format); + if (format == kOneByteWithColHeaders) { + PerColHeader *per_col_header = reinterpret_cast(h+1); + uint8 *byte_data = reinterpret_cast(per_col_header + + h->num_cols); + for (int32 i = 0; i < num_cols; i++, per_col_header++) { + float p0 = Uint16ToFloat(*h, per_col_header->percentile_0), + p25 = Uint16ToFloat(*h, per_col_header->percentile_25), + p75 = Uint16ToFloat(*h, per_col_header->percentile_75), + p100 = 
Uint16ToFloat(*h, per_col_header->percentile_100); + for (int32 j = 0; j < num_rows; j++, byte_data++) { + float f = CharToFloat(p0, p25, p75, p100, *byte_data); + (*mat)(j, i) = f; + } + } + } else if (format == kTwoByte) { + const uint16 *data = reinterpret_cast(h + 1); + float min_value = h->min_value, + increment = h->range * (1.0 / 65535.0); + for (int32 i = 0; i < num_rows; i++) { + Real *row_data = mat->RowData(i); + for (int32 j = 0; j < num_cols; j++) + row_data[j] = min_value + data[j] * increment; + data += num_cols; + } + } else { + KALDI_ASSERT(format == kOneByte); + float min_value = h->min_value, increment = h->range * (1.0 / 255.0); + + const uint8 *data = reinterpret_cast(h + 1); + for (int32 i = 0; i < num_rows; i++) { + Real *row_data = mat->RowData(i); + for (int32 j = 0; j < num_cols; j++) + row_data[j] = min_value + data[j] * increment; + data += num_cols; + } + } +} + +// Instantiate the template for float and double. +template +void CompressedMatrix::CopyToMat(MatrixBase *mat, + MatrixTransposeType trans) const; +template +void CompressedMatrix::CopyToMat(MatrixBase *mat, + MatrixTransposeType trans) const; + +template +void CompressedMatrix::CopyRowToVec(MatrixIndexT row, + VectorBase *v) const { + KALDI_ASSERT(row < this->NumRows()); + KALDI_ASSERT(row >= 0); + KALDI_ASSERT(v->Dim() == this->NumCols()); + + GlobalHeader *h = reinterpret_cast(data_); + DataFormat format = static_cast(h->format); + if (format == kOneByteWithColHeaders) { + PerColHeader *per_col_header = reinterpret_cast(h+1); + uint8 *byte_data = reinterpret_cast(per_col_header + + h->num_cols); + byte_data += row; // point to first value we are interested in + for (int32 i = 0; i < h->num_cols; + i++, per_col_header++, byte_data += h->num_rows) { + float p0 = Uint16ToFloat(*h, per_col_header->percentile_0), + p25 = Uint16ToFloat(*h, per_col_header->percentile_25), + p75 = Uint16ToFloat(*h, per_col_header->percentile_75), + p100 = Uint16ToFloat(*h, 
per_col_header->percentile_100); + float f = CharToFloat(p0, p25, p75, p100, *byte_data); + (*v)(i) = f; + } + } else if (format == kTwoByte) { + int32 num_cols = h->num_cols; + float min_value = h->min_value, + increment = h->range * (1.0 / 65535.0); + const uint16 *row_data = reinterpret_cast(h + 1) + (num_cols * row); + Real *v_data = v->Data(); + for (int32 c = 0; c < num_cols; c++) + v_data[c] = min_value + row_data[c] * increment; + } else { + KALDI_ASSERT(format == kOneByte); + int32 num_cols = h->num_cols; + float min_value = h->min_value, + increment = h->range * (1.0 / 255.0); + const uint8 *row_data = reinterpret_cast(h + 1) + (num_cols * row); + Real *v_data = v->Data(); + for (int32 c = 0; c < num_cols; c++) + v_data[c] = min_value + row_data[c] * increment; + } +} + +template +void CompressedMatrix::CopyColToVec(MatrixIndexT col, + VectorBase *v) const { + KALDI_ASSERT(col < this->NumCols()); + KALDI_ASSERT(col >= 0); + KALDI_ASSERT(v->Dim() == this->NumRows()); + + GlobalHeader *h = reinterpret_cast(data_); + + DataFormat format = static_cast(h->format); + if (format == kOneByteWithColHeaders) { + PerColHeader *per_col_header = reinterpret_cast(h+1); + uint8 *byte_data = reinterpret_cast(per_col_header + + h->num_cols); + byte_data += col*h->num_rows; // point to first value in the column we want + per_col_header += col; + float p0 = Uint16ToFloat(*h, per_col_header->percentile_0), + p25 = Uint16ToFloat(*h, per_col_header->percentile_25), + p75 = Uint16ToFloat(*h, per_col_header->percentile_75), + p100 = Uint16ToFloat(*h, per_col_header->percentile_100); + for (int32 i = 0; i < h->num_rows; i++, byte_data++) { + float f = CharToFloat(p0, p25, p75, p100, *byte_data); + (*v)(i) = f; + } + } else if (format == kTwoByte) { + int32 num_rows = h->num_rows, num_cols = h->num_cols; + float min_value = h->min_value, + increment = h->range * (1.0 / 65535.0); + const uint16 *col_data = reinterpret_cast(h + 1) + col; + Real *v_data = v->Data(); + for (int32 r = 
0; r < num_rows; r++) + v_data[r] = min_value + increment * col_data[r * num_cols]; + } else { + KALDI_ASSERT(format == kOneByte); + int32 num_rows = h->num_rows, num_cols = h->num_cols; + float min_value = h->min_value, + increment = h->range * (1.0 / 255.0); + const uint8 *col_data = reinterpret_cast(h + 1) + col; + Real *v_data = v->Data(); + for (int32 r = 0; r < num_rows; r++) + v_data[r] = min_value + increment * col_data[r * num_cols]; + } +} + +// instantiate the templates. +template void +CompressedMatrix::CopyColToVec(MatrixIndexT, VectorBase *) const; +template void +CompressedMatrix::CopyColToVec(MatrixIndexT, VectorBase *) const; +template void +CompressedMatrix::CopyRowToVec(MatrixIndexT, VectorBase *) const; +template void +CompressedMatrix::CopyRowToVec(MatrixIndexT, VectorBase *) const; + +template +void CompressedMatrix::CopyToMat(int32 row_offset, + int32 col_offset, + MatrixBase *dest) const { + KALDI_PARANOID_ASSERT(row_offset < this->NumRows()); + KALDI_PARANOID_ASSERT(col_offset < this->NumCols()); + KALDI_PARANOID_ASSERT(row_offset >= 0); + KALDI_PARANOID_ASSERT(col_offset >= 0); + KALDI_ASSERT(row_offset+dest->NumRows() <= this->NumRows()); + KALDI_ASSERT(col_offset+dest->NumCols() <= this->NumCols()); + // everything is OK + GlobalHeader *h = reinterpret_cast(data_); + int32 num_rows = h->num_rows, num_cols = h->num_cols, + tgt_cols = dest->NumCols(), tgt_rows = dest->NumRows(); + + DataFormat format = static_cast(h->format); + if (format == kOneByteWithColHeaders) { + PerColHeader *per_col_header = reinterpret_cast(h+1); + uint8 *byte_data = reinterpret_cast(per_col_header + + h->num_cols); + + uint8 *start_of_subcol = byte_data+row_offset; // skip appropriate + // number of columns + start_of_subcol += col_offset*num_rows; // skip appropriate number of rows + + per_col_header += col_offset; // skip the appropriate number of headers + + for (int32 i = 0; + i < tgt_cols; + i++, per_col_header++, start_of_subcol+=num_rows) { + byte_data = 
start_of_subcol; + float p0 = Uint16ToFloat(*h, per_col_header->percentile_0), + p25 = Uint16ToFloat(*h, per_col_header->percentile_25), + p75 = Uint16ToFloat(*h, per_col_header->percentile_75), + p100 = Uint16ToFloat(*h, per_col_header->percentile_100); + for (int32 j = 0; j < tgt_rows; j++, byte_data++) { + float f = CharToFloat(p0, p25, p75, p100, *byte_data); + (*dest)(j, i) = f; + } + } + } else if (format == kTwoByte) { + const uint16 *data = reinterpret_cast(h+1) + col_offset + + (num_cols * row_offset); + float min_value = h->min_value, + increment = h->range * (1.0 / 65535.0); + + for (int32 row = 0; row < tgt_rows; row++) { + Real *dest_row = dest->RowData(row); + for (int32 col = 0; col < tgt_cols; col++) + dest_row[col] = min_value + increment * data[col]; + data += num_cols; + } + } else { + KALDI_ASSERT(format == kOneByte); + const uint8 *data = reinterpret_cast(h+1) + col_offset + + (num_cols * row_offset); + float min_value = h->min_value, + increment = h->range * (1.0 / 255.0); + for (int32 row = 0; row < tgt_rows; row++) { + Real *dest_row = dest->RowData(row); + for (int32 col = 0; col < tgt_cols; col++) + dest_row[col] = min_value + increment * data[col]; + data += num_cols; + } + } +} + +// instantiate the templates. +template void CompressedMatrix::CopyToMat(int32, + int32, + MatrixBase *dest) const; +template void CompressedMatrix::CopyToMat(int32, + int32, + MatrixBase *dest) const; + +void CompressedMatrix::Clear() { + if (data_ != NULL) { + delete [] static_cast(data_); + data_ = NULL; + } +} + +CompressedMatrix::CompressedMatrix(const CompressedMatrix &mat): data_(NULL) { + *this = mat; // use assignment operator. +} + +CompressedMatrix &CompressedMatrix::operator = (const CompressedMatrix &mat) { + Clear(); // now this->data_ == NULL. 
+ if (mat.data_ != NULL) { + MatrixIndexT data_size = DataSize(*static_cast(mat.data_)); + data_ = AllocateData(data_size); + memcpy(static_cast(data_), + static_cast(mat.data_), + data_size); + } + return *this; +} + + +} // namespace kaldi diff --git a/speechx/speechx/kaldi/matrix/compressed-matrix.h b/speechx/speechx/kaldi/matrix/compressed-matrix.h new file mode 100644 index 00000000..78105b9b --- /dev/null +++ b/speechx/speechx/kaldi/matrix/compressed-matrix.h @@ -0,0 +1,283 @@ +// matrix/compressed-matrix.h + +// Copyright 2012 Johns Hopkins University (author: Daniel Povey) +// Frantisek Skala, Wei Shi + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_COMPRESSED_MATRIX_H_ +#define KALDI_MATRIX_COMPRESSED_MATRIX_H_ 1 + +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + +/// \addtogroup matrix_group +/// @{ + + + +/* + The enum CompressionMethod is used when creating a CompressedMatrix (a lossily + compressed matrix) from a regular Matrix. It dictates how we choose the + compressed format and how we choose the ranges of floats that are represented + by particular integers. + + kAutomaticMethod = 1 This is the default when you don't specify the + compression method. It is a shorthand for using + kSpeechFeature if the num-rows is more than 8, and + kTwoByteAuto otherwise. 
+ kSpeechFeature = 2 This is the most complicated of the compression methods, + and was designed for speech features which have a roughly + Gaussian distribution with different ranges for each + dimension. Each element is stored in one byte, but there + is an 8-byte header per column; the spacing of the + integer values is not uniform but is in 3 ranges. + kTwoByteAuto = 3 Each element is stored in two bytes as a uint16, with + the representable range of values chosen automatically + with the minimum and maximum elements of the matrix as + its edges. + kTwoByteSignedInteger = 4 + Each element is stored in two bytes as a uint16, with + the representable range of value chosen to coincide with + what you'd get if you stored signed integers, i.e. + [-32768.0, 32767.0]. Suitable for waveform data that + was previously stored as 16-bit PCM. + kOneByteAuto = 5 Each element is stored in one byte as a uint8, with the + representable range of values chosen automatically with + the minimum and maximum elements of the matrix as its + edges. + kOneByteUnsignedInteger = 6 Each element is stored in + one byte as a uint8, with the representable range of + values equal to [0.0, 255.0]. + kOneByteZeroOne = 7 Each element is stored in + one byte as a uint8, with the representable range of + values equal to [0.0, 1.0]. Suitable for image data + that has previously been compressed as int8. + + // We can add new methods here as needed: if they just imply different ways + // of selecting the min_value and range, and a num-bytes = 1 or 2, they will + // be trivial to implement. +*/ +enum CompressionMethod { + kAutomaticMethod = 1, + kSpeechFeature = 2, + kTwoByteAuto = 3, + kTwoByteSignedInteger = 4, + kOneByteAuto = 5, + kOneByteUnsignedInteger = 6, + kOneByteZeroOne = 7 +}; + + +/* + This class does lossy compression of a matrix. It supports various compression + methods, see enum CompressionMethod. 
+*/ + +class CompressedMatrix { + public: + CompressedMatrix(): data_(NULL) { } + + ~CompressedMatrix() { Clear(); } + + template + explicit CompressedMatrix(const MatrixBase &mat, + CompressionMethod method = kAutomaticMethod): + data_(NULL) { CopyFromMat(mat, method); } + + /// Initializer that can be used to select part of an existing + /// CompressedMatrix without un-compressing and re-compressing (note: unlike + /// similar initializers for class Matrix, it doesn't point to the same memory + /// location). + /// + /// This creates a CompressedMatrix with the size (num_rows, num_cols) + /// starting at (row_offset, col_offset). + /// + /// If you specify allow_padding = true, + /// it is permitted to have row_offset < 0 and + /// row_offset + num_rows > mat.NumRows(), and the result will contain + /// repeats of the first and last rows of 'mat' as necessary. + CompressedMatrix(const CompressedMatrix &mat, + const MatrixIndexT row_offset, + const MatrixIndexT num_rows, + const MatrixIndexT col_offset, + const MatrixIndexT num_cols, + bool allow_padding = false); + + void *Data() const { return this->data_; } + + /// This will resize *this and copy the contents of mat to *this. + template + void CopyFromMat(const MatrixBase &mat, + CompressionMethod method = kAutomaticMethod); + + CompressedMatrix(const CompressedMatrix &mat); + + CompressedMatrix &operator = (const CompressedMatrix &mat); // assignment operator. + + template + CompressedMatrix &operator = (const MatrixBase &mat); // assignment operator. + + /// Copies contents to matrix. Note: mat must have the correct size. + /// The kTrans case uses a temporary. + template + void CopyToMat(MatrixBase *mat, + MatrixTransposeType trans = kNoTrans) const; + + void Write(std::ostream &os, bool binary) const; + + void Read(std::istream &is, bool binary); + + /// Returns number of rows (or zero for emtpy matrix). + inline MatrixIndexT NumRows() const { return (data_ == NULL) ? 
0 : + (*reinterpret_cast(data_)).num_rows; } + + /// Returns number of columns (or zero for emtpy matrix). + inline MatrixIndexT NumCols() const { return (data_ == NULL) ? 0 : + (*reinterpret_cast(data_)).num_cols; } + + /// Copies row #row of the matrix into vector v. + /// Note: v must have same size as #cols. + template + void CopyRowToVec(MatrixIndexT row, VectorBase *v) const; + + /// Copies column #col of the matrix into vector v. + /// Note: v must have same size as #rows. + template + void CopyColToVec(MatrixIndexT col, VectorBase *v) const; + + /// Copies submatrix of compressed matrix into matrix dest. + /// Submatrix starts at row row_offset and column column_offset and its size + /// is defined by size of provided matrix dest + template + void CopyToMat(int32 row_offset, + int32 column_offset, + MatrixBase *dest) const; + + void Swap(CompressedMatrix *other) { std::swap(data_, other->data_); } + + void Clear(); + + /// scales all elements of matrix by alpha. + /// It scales the floating point values in GlobalHeader by alpha. + void Scale(float alpha); + + friend class Matrix; + friend class Matrix; + private: + + // This enum describes the different compressed-data formats: these are + // distinct from the compression methods although all of the methods apart + // from kAutomaticMethod dictate a particular compressed-data format. + // + // kOneByteWithColHeaders means there is a GlobalHeader and each + // column has a PerColHeader; the actual data is stored in + // one byte per element, in column-major order (the mapping + // from integers to floats is a little complicated). 
+ // kTwoByte means there is a global header but no PerColHeader; + // the actual data is stored in two bytes per element in + // row-major order; it's decompressed as: + // uint16 i; GlobalHeader g; + // float f = g.min_value + i * (g.range / 65535.0) + // kOneByte means there is a global header but not PerColHeader; + // the data is stored in one byte per element in row-major + // order and is decompressed as: + // uint8 i; GlobalHeader g; + // float f = g.min_value + i * (g.range / 255.0) + enum DataFormat { + kOneByteWithColHeaders = 1, + kTwoByte = 2, + kOneByte = 3 + }; + + + // allocates data using new [], ensures byte alignment + // sufficient for float. + static void *AllocateData(int32 num_bytes); + + struct GlobalHeader { + int32 format; // Represents the enum DataFormat. + float min_value; // min_value and range represent the ranges of the integer + // data in the kTwoByte and kOneByte formats, and the + // range of the PerColHeader uint16's in the + // kOneByteWithColheaders format. + float range; + int32 num_rows; + int32 num_cols; + }; + + // This function computes the global header for compressing this data. + template + static inline void ComputeGlobalHeader(const MatrixBase &mat, + CompressionMethod method, + GlobalHeader *header); + + + // The number of bytes we need to request when allocating 'data_'. + static MatrixIndexT DataSize(const GlobalHeader &header); + + // This struct is only used in format kOneByteWithColHeaders. 
+ struct PerColHeader { + uint16 percentile_0; + uint16 percentile_25; + uint16 percentile_75; + uint16 percentile_100; + }; + + template + static void CompressColumn(const GlobalHeader &global_header, + const Real *data, MatrixIndexT stride, + int32 num_rows, PerColHeader *header, + uint8 *byte_data); + template + static void ComputeColHeader(const GlobalHeader &global_header, + const Real *data, MatrixIndexT stride, + int32 num_rows, PerColHeader *header); + + static inline uint16 FloatToUint16(const GlobalHeader &global_header, + float value); + + // this is used only in the kOneByte compression format. + static inline uint8 FloatToUint8(const GlobalHeader &global_header, + float value); + + static inline float Uint16ToFloat(const GlobalHeader &global_header, + uint16 value); + + // this is used only in the kOneByteWithColHeaders compression format. + static inline uint8 FloatToChar(float p0, float p25, + float p75, float p100, + float value); + + // this is used only in the kOneByteWithColHeaders compression format. + static inline float CharToFloat(float p0, float p25, + float p75, float p100, + uint8 value); + + void *data_; // first GlobalHeader, then PerColHeader (repeated), then + // the byte data for each column (repeated). Note: don't intersperse + // the byte data with the PerColHeaders, because of alignment issues. + +}; + +/// @} end of \addtogroup matrix_group + + +} // namespace kaldi + + +#endif // KALDI_MATRIX_COMPRESSED_MATRIX_H_ diff --git a/speechx/speechx/kaldi/matrix/jama-eig.h b/speechx/speechx/kaldi/matrix/jama-eig.h new file mode 100644 index 00000000..92d8c27e --- /dev/null +++ b/speechx/speechx/kaldi/matrix/jama-eig.h @@ -0,0 +1,924 @@ +// matrix/jama-eig.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// This file consists of a port and modification of materials from +// JAMA: A Java Matrix Package +// under the following notice: This software is a cooperative product of +// The MathWorks and the National Institute of Standards and Technology (NIST) +// which has been released to the public. This notice and the original code are +// available at http://math.nist.gov/javanumerics/jama/domain.notice + + + +#ifndef KALDI_MATRIX_JAMA_EIG_H_ +#define KALDI_MATRIX_JAMA_EIG_H_ 1 + +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + +// This class is not to be used externally. See the Eig function in the Matrix +// class in kaldi-matrix.h. This is the external interface. + +template class EigenvalueDecomposition { + // This class is based on the EigenvalueDecomposition class from the JAMA + // library (version 1.0.2). + public: + EigenvalueDecomposition(const MatrixBase &A); + + ~EigenvalueDecomposition(); // free memory. + + void GetV(MatrixBase *V_out) { // V is what we call P externally; it's the matrix of + // eigenvectors. + KALDI_ASSERT(V_out->NumRows() == static_cast(n_) + && V_out->NumCols() == static_cast(n_)); + for (int i = 0; i < n_; i++) + for (int j = 0; j < n_; j++) + (*V_out)(i, j) = V(i, j); // V(i, j) is member function. + } + void GetRealEigenvalues(VectorBase *r_out) { + // returns real part of eigenvalues. 
+ KALDI_ASSERT(r_out->Dim() == static_cast(n_)); + for (int i = 0; i < n_; i++) + (*r_out)(i) = d_[i]; + } + void GetImagEigenvalues(VectorBase *i_out) { + // returns imaginary part of eigenvalues. + KALDI_ASSERT(i_out->Dim() == static_cast(n_)); + for (int i = 0; i < n_; i++) + (*i_out)(i) = e_[i]; + } + private: + + inline Real &H(int r, int c) { return H_[r*n_ + c]; } + inline Real &V(int r, int c) { return V_[r*n_ + c]; } + + // complex division + inline static void cdiv(Real xr, Real xi, Real yr, Real yi, Real *cdivr, Real *cdivi) { + Real r, d; + if (std::abs(yr) > std::abs(yi)) { + r = yi/yr; + d = yr + r*yi; + *cdivr = (xr + r*xi)/d; + *cdivi = (xi - r*xr)/d; + } else { + r = yr/yi; + d = yi + r*yr; + *cdivr = (r*xr + xi)/d; + *cdivi = (r*xi - xr)/d; + } + } + + // Nonsymmetric reduction from Hessenberg to real Schur form. + void Hqr2 (); + + + int n_; // matrix dimension. + + Real *d_, *e_; // real and imaginary parts of eigenvalues. + Real *V_; // the eigenvectors (P in our external notation) + Real *H_; // the nonsymmetric Hessenberg form. + Real *ort_; // working storage for nonsymmetric algorithm. + + // Symmetric Householder reduction to tridiagonal form. + void Tred2 (); + + // Symmetric tridiagonal QL algorithm. + void Tql2 (); + + // Nonsymmetric reduction to Hessenberg form. + void Orthes (); + +}; + +template class EigenvalueDecomposition; // force instantiation. +template class EigenvalueDecomposition; // force instantiation. + +template void EigenvalueDecomposition::Tred2() { + // This is derived from the Algol procedures tred2 by + // Bowdler, Martin, Reinsch, and Wilkinson, Handbook for + // Auto. Comp., Vol.ii-Linear Algebra, and the corresponding + // Fortran subroutine in EISPACK. + + for (int j = 0; j < n_; j++) { + d_[j] = V(n_-1, j); + } + + // Householder reduction to tridiagonal form. + + for (int i = n_-1; i > 0; i--) { + + // Scale to avoid under/overflow. 
+ + Real scale = 0.0; + Real h = 0.0; + for (int k = 0; k < i; k++) { + scale = scale + std::abs(d_[k]); + } + if (scale == 0.0) { + e_[i] = d_[i-1]; + for (int j = 0; j < i; j++) { + d_[j] = V(i-1, j); + V(i, j) = 0.0; + V(j, i) = 0.0; + } + } else { + + // Generate Householder vector. + + for (int k = 0; k < i; k++) { + d_[k] /= scale; + h += d_[k] * d_[k]; + } + Real f = d_[i-1]; + Real g = std::sqrt(h); + if (f > 0) { + g = -g; + } + e_[i] = scale * g; + h = h - f * g; + d_[i-1] = f - g; + for (int j = 0; j < i; j++) { + e_[j] = 0.0; + } + + // Apply similarity transformation to remaining columns. + + for (int j = 0; j < i; j++) { + f = d_[j]; + V(j, i) = f; + g =e_[j] + V(j, j) * f; + for (int k = j+1; k <= i-1; k++) { + g += V(k, j) * d_[k]; + e_[k] += V(k, j) * f; + } + e_[j] = g; + } + f = 0.0; + for (int j = 0; j < i; j++) { + e_[j] /= h; + f += e_[j] * d_[j]; + } + Real hh = f / (h + h); + for (int j = 0; j < i; j++) { + e_[j] -= hh * d_[j]; + } + for (int j = 0; j < i; j++) { + f = d_[j]; + g = e_[j]; + for (int k = j; k <= i-1; k++) { + V(k, j) -= (f * e_[k] + g * d_[k]); + } + d_[j] = V(i-1, j); + V(i, j) = 0.0; + } + } + d_[i] = h; + } + + // Accumulate transformations. + + for (int i = 0; i < n_-1; i++) { + V(n_-1, i) = V(i, i); + V(i, i) = 1.0; + Real h = d_[i+1]; + if (h != 0.0) { + for (int k = 0; k <= i; k++) { + d_[k] = V(k, i+1) / h; + } + for (int j = 0; j <= i; j++) { + Real g = 0.0; + for (int k = 0; k <= i; k++) { + g += V(k, i+1) * V(k, j); + } + for (int k = 0; k <= i; k++) { + V(k, j) -= g * d_[k]; + } + } + } + for (int k = 0; k <= i; k++) { + V(k, i+1) = 0.0; + } + } + for (int j = 0; j < n_; j++) { + d_[j] = V(n_-1, j); + V(n_-1, j) = 0.0; + } + V(n_-1, n_-1) = 1.0; + e_[0] = 0.0; +} + +template void EigenvalueDecomposition::Tql2() { + // This is derived from the Algol procedures tql2, by + // Bowdler, Martin, Reinsch, and Wilkinson, Handbook for + // Auto. 
Comp., Vol.ii-Linear Algebra, and the corresponding + // Fortran subroutine in EISPACK. + + for (int i = 1; i < n_; i++) { + e_[i-1] = e_[i]; + } + e_[n_-1] = 0.0; + + Real f = 0.0; + Real tst1 = 0.0; + Real eps = std::numeric_limits::epsilon(); + for (int l = 0; l < n_; l++) { + + // Find small subdiagonal element + + tst1 = std::max(tst1, std::abs(d_[l]) + std::abs(e_[l])); + int m = l; + while (m < n_) { + if (std::abs(e_[m]) <= eps*tst1) { + break; + } + m++; + } + + // If m == l, d_[l] is an eigenvalue, + // otherwise, iterate. + + if (m > l) { + int iter = 0; + do { + iter = iter + 1; // (Could check iteration count here.) + + // Compute implicit shift + + Real g = d_[l]; + Real p = (d_[l+1] - g) / (2.0 *e_[l]); + Real r = Hypot(p, static_cast(1.0)); // This is a Kaldi version of hypot that works with templates. + if (p < 0) { + r = -r; + } + d_[l] =e_[l] / (p + r); + d_[l+1] =e_[l] * (p + r); + Real dl1 = d_[l+1]; + Real h = g - d_[l]; + for (int i = l+2; i < n_; i++) { + d_[i] -= h; + } + f = f + h; + + // Implicit QL transformation. + + p = d_[m]; + Real c = 1.0; + Real c2 = c; + Real c3 = c; + Real el1 =e_[l+1]; + Real s = 0.0; + Real s2 = 0.0; + for (int i = m-1; i >= l; i--) { + c3 = c2; + c2 = c; + s2 = s; + g = c *e_[i]; + h = c * p; + r = Hypot(p, e_[i]); // This is a Kaldi version of Hypot that works with templates. + e_[i+1] = s * r; + s =e_[i] / r; + c = p / r; + p = c * d_[i] - s * g; + d_[i+1] = h + s * (c * g + s * d_[i]); + + // Accumulate transformation. + + for (int k = 0; k < n_; k++) { + h = V(k, i+1); + V(k, i+1) = s * V(k, i) + c * h; + V(k, i) = c * V(k, i) - s * h; + } + } + p = -s * s2 * c3 * el1 *e_[l] / dl1; + e_[l] = s * p; + d_[l] = c * p; + + // Check for convergence. + + } while (std::abs(e_[l]) > eps*tst1); + } + d_[l] = d_[l] + f; + e_[l] = 0.0; + } + + // Sort eigenvalues and corresponding vectors. 
+ + for (int i = 0; i < n_-1; i++) { + int k = i; + Real p = d_[i]; + for (int j = i+1; j < n_; j++) { + if (d_[j] < p) { + k = j; + p = d_[j]; + } + } + if (k != i) { + d_[k] = d_[i]; + d_[i] = p; + for (int j = 0; j < n_; j++) { + p = V(j, i); + V(j, i) = V(j, k); + V(j, k) = p; + } + } + } +} + +template +void EigenvalueDecomposition::Orthes() { + + // This is derived from the Algol procedures orthes and ortran, + // by Martin and Wilkinson, Handbook for Auto. Comp., + // Vol.ii-Linear Algebra, and the corresponding + // Fortran subroutines in EISPACK. + + int low = 0; + int high = n_-1; + + for (int m = low+1; m <= high-1; m++) { + + // Scale column. + + Real scale = 0.0; + for (int i = m; i <= high; i++) { + scale = scale + std::abs(H(i, m-1)); + } + if (scale != 0.0) { + + // Compute Householder transformation. + + Real h = 0.0; + for (int i = high; i >= m; i--) { + ort_[i] = H(i, m-1)/scale; + h += ort_[i] * ort_[i]; + } + Real g = std::sqrt(h); + if (ort_[m] > 0) { + g = -g; + } + h = h - ort_[m] * g; + ort_[m] = ort_[m] - g; + + // Apply Householder similarity transformation + // H = (I-u*u'/h)*H*(I-u*u')/h) + + for (int j = m; j < n_; j++) { + Real f = 0.0; + for (int i = high; i >= m; i--) { + f += ort_[i]*H(i, j); + } + f = f/h; + for (int i = m; i <= high; i++) { + H(i, j) -= f*ort_[i]; + } + } + + for (int i = 0; i <= high; i++) { + Real f = 0.0; + for (int j = high; j >= m; j--) { + f += ort_[j]*H(i, j); + } + f = f/h; + for (int j = m; j <= high; j++) { + H(i, j) -= f*ort_[j]; + } + } + ort_[m] = scale*ort_[m]; + H(m, m-1) = scale*g; + } + } + + // Accumulate transformations (Algol's ortran). + + for (int i = 0; i < n_; i++) { + for (int j = 0; j < n_; j++) { + V(i, j) = (i == j ? 
1.0 : 0.0); + } + } + + for (int m = high-1; m >= low+1; m--) { + if (H(m, m-1) != 0.0) { + for (int i = m+1; i <= high; i++) { + ort_[i] = H(i, m-1); + } + for (int j = m; j <= high; j++) { + Real g = 0.0; + for (int i = m; i <= high; i++) { + g += ort_[i] * V(i, j); + } + // Double division avoids possible underflow + g = (g / ort_[m]) / H(m, m-1); + for (int i = m; i <= high; i++) { + V(i, j) += g * ort_[i]; + } + } + } + } +} + +template void EigenvalueDecomposition::Hqr2() { + // This is derived from the Algol procedure hqr2, + // by Martin and Wilkinson, Handbook for Auto. Comp., + // Vol.ii-Linear Algebra, and the corresponding + // Fortran subroutine in EISPACK. + + int nn = n_; + int n = nn-1; + int low = 0; + int high = nn-1; + Real eps = std::numeric_limits::epsilon(); + Real exshift = 0.0; + Real p = 0, q = 0, r = 0, s = 0, z=0, t, w, x, y; + + // Store roots isolated by balanc and compute matrix norm + + Real norm = 0.0; + for (int i = 0; i < nn; i++) { + if (i < low || i > high) { + d_[i] = H(i, i); + e_[i] = 0.0; + } + for (int j = std::max(i-1, 0); j < nn; j++) { + norm = norm + std::abs(H(i, j)); + } + } + + // Outer loop over eigenvalue index + + int iter = 0; + while (n >= low) { + + // Look for single small sub-diagonal element + + int l = n; + while (l > low) { + s = std::abs(H(l-1, l-1)) + std::abs(H(l, l)); + if (s == 0.0) { + s = norm; + } + if (std::abs(H(l, l-1)) < eps * s) { + break; + } + l--; + } + + // Check for convergence + // One root found + + if (l == n) { + H(n, n) = H(n, n) + exshift; + d_[n] = H(n, n); + e_[n] = 0.0; + n--; + iter = 0; + + // Two roots found + + } else if (l == n-1) { + w = H(n, n-1) * H(n-1, n); + p = (H(n-1, n-1) - H(n, n)) / 2.0; + q = p * p + w; + z = std::sqrt(std::abs(q)); + H(n, n) = H(n, n) + exshift; + H(n-1, n-1) = H(n-1, n-1) + exshift; + x = H(n, n); + + // Real pair + + if (q >= 0) { + if (p >= 0) { + z = p + z; + } else { + z = p - z; + } + d_[n-1] = x + z; + d_[n] = d_[n-1]; + if (z != 0.0) { + 
d_[n] = x - w / z; + } + e_[n-1] = 0.0; + e_[n] = 0.0; + x = H(n, n-1); + s = std::abs(x) + std::abs(z); + p = x / s; + q = z / s; + r = std::sqrt(p * p+q * q); + p = p / r; + q = q / r; + + // Row modification + + for (int j = n-1; j < nn; j++) { + z = H(n-1, j); + H(n-1, j) = q * z + p * H(n, j); + H(n, j) = q * H(n, j) - p * z; + } + + // Column modification + + for (int i = 0; i <= n; i++) { + z = H(i, n-1); + H(i, n-1) = q * z + p * H(i, n); + H(i, n) = q * H(i, n) - p * z; + } + + // Accumulate transformations + + for (int i = low; i <= high; i++) { + z = V(i, n-1); + V(i, n-1) = q * z + p * V(i, n); + V(i, n) = q * V(i, n) - p * z; + } + + // Complex pair + + } else { + d_[n-1] = x + p; + d_[n] = x + p; + e_[n-1] = z; + e_[n] = -z; + } + n = n - 2; + iter = 0; + + // No convergence yet + + } else { + + // Form shift + + x = H(n, n); + y = 0.0; + w = 0.0; + if (l < n) { + y = H(n-1, n-1); + w = H(n, n-1) * H(n-1, n); + } + + // Wilkinson's original ad hoc shift + + if (iter == 10) { + exshift += x; + for (int i = low; i <= n; i++) { + H(i, i) -= x; + } + s = std::abs(H(n, n-1)) + std::abs(H(n-1, n-2)); + x = y = 0.75 * s; + w = -0.4375 * s * s; + } + + // MATLAB's new ad hoc shift + + if (iter == 30) { + s = (y - x) / 2.0; + s = s * s + w; + if (s > 0) { + s = std::sqrt(s); + if (y < x) { + s = -s; + } + s = x - w / ((y - x) / 2.0 + s); + for (int i = low; i <= n; i++) { + H(i, i) -= s; + } + exshift += s; + x = y = w = 0.964; + } + } + + iter = iter + 1; // (Could check iteration count here.) 
+ + // Look for two consecutive small sub-diagonal elements + + int m = n-2; + while (m >= l) { + z = H(m, m); + r = x - z; + s = y - z; + p = (r * s - w) / H(m+1, m) + H(m, m+1); + q = H(m+1, m+1) - z - r - s; + r = H(m+2, m+1); + s = std::abs(p) + std::abs(q) + std::abs(r); + p = p / s; + q = q / s; + r = r / s; + if (m == l) { + break; + } + if (std::abs(H(m, m-1)) * (std::abs(q) + std::abs(r)) < + eps * (std::abs(p) * (std::abs(H(m-1, m-1)) + std::abs(z) + + std::abs(H(m+1, m+1))))) { + break; + } + m--; + } + + for (int i = m+2; i <= n; i++) { + H(i, i-2) = 0.0; + if (i > m+2) { + H(i, i-3) = 0.0; + } + } + + // Double QR step involving rows l:n and columns m:n + + for (int k = m; k <= n-1; k++) { + bool notlast = (k != n-1); + if (k != m) { + p = H(k, k-1); + q = H(k+1, k-1); + r = (notlast ? H(k+2, k-1) : 0.0); + x = std::abs(p) + std::abs(q) + std::abs(r); + if (x != 0.0) { + p = p / x; + q = q / x; + r = r / x; + } + } + if (x == 0.0) { + break; + } + s = std::sqrt(p * p + q * q + r * r); + if (p < 0) { + s = -s; + } + if (s != 0) { + if (k != m) { + H(k, k-1) = -s * x; + } else if (l != m) { + H(k, k-1) = -H(k, k-1); + } + p = p + s; + x = p / s; + y = q / s; + z = r / s; + q = q / p; + r = r / p; + + // Row modification + + for (int j = k; j < nn; j++) { + p = H(k, j) + q * H(k+1, j); + if (notlast) { + p = p + r * H(k+2, j); + H(k+2, j) = H(k+2, j) - p * z; + } + H(k, j) = H(k, j) - p * x; + H(k+1, j) = H(k+1, j) - p * y; + } + + // Column modification + + for (int i = 0; i <= std::min(n, k+3); i++) { + p = x * H(i, k) + y * H(i, k+1); + if (notlast) { + p = p + z * H(i, k+2); + H(i, k+2) = H(i, k+2) - p * r; + } + H(i, k) = H(i, k) - p; + H(i, k+1) = H(i, k+1) - p * q; + } + + // Accumulate transformations + + for (int i = low; i <= high; i++) { + p = x * V(i, k) + y * V(i, k+1); + if (notlast) { + p = p + z * V(i, k+2); + V(i, k+2) = V(i, k+2) - p * r; + } + V(i, k) = V(i, k) - p; + V(i, k+1) = V(i, k+1) - p * q; + } + } // (s != 0) + } // k loop + } 
// check convergence + } // while (n >= low) + + // Backsubstitute to find vectors of upper triangular form + + if (norm == 0.0) { + return; + } + + for (n = nn-1; n >= 0; n--) { + p = d_[n]; + q = e_[n]; + + // Real vector + + if (q == 0) { + int l = n; + H(n, n) = 1.0; + for (int i = n-1; i >= 0; i--) { + w = H(i, i) - p; + r = 0.0; + for (int j = l; j <= n; j++) { + r = r + H(i, j) * H(j, n); + } + if (e_[i] < 0.0) { + z = w; + s = r; + } else { + l = i; + if (e_[i] == 0.0) { + if (w != 0.0) { + H(i, n) = -r / w; + } else { + H(i, n) = -r / (eps * norm); + } + + // Solve real equations + + } else { + x = H(i, i+1); + y = H(i+1, i); + q = (d_[i] - p) * (d_[i] - p) +e_[i] *e_[i]; + t = (x * s - z * r) / q; + H(i, n) = t; + if (std::abs(x) > std::abs(z)) { + H(i+1, n) = (-r - w * t) / x; + } else { + H(i+1, n) = (-s - y * t) / z; + } + } + + // Overflow control + + t = std::abs(H(i, n)); + if ((eps * t) * t > 1) { + for (int j = i; j <= n; j++) { + H(j, n) = H(j, n) / t; + } + } + } + } + + // Complex vector + + } else if (q < 0) { + int l = n-1; + + // Last vector component imaginary so matrix is triangular + + if (std::abs(H(n, n-1)) > std::abs(H(n-1, n))) { + H(n-1, n-1) = q / H(n, n-1); + H(n-1, n) = -(H(n, n) - p) / H(n, n-1); + } else { + Real cdivr, cdivi; + cdiv(0.0, -H(n-1, n), H(n-1, n-1)-p, q, &cdivr, &cdivi); + H(n-1, n-1) = cdivr; + H(n-1, n) = cdivi; + } + H(n, n-1) = 0.0; + H(n, n) = 1.0; + for (int i = n-2; i >= 0; i--) { + Real ra, sa, vr, vi; + ra = 0.0; + sa = 0.0; + for (int j = l; j <= n; j++) { + ra = ra + H(i, j) * H(j, n-1); + sa = sa + H(i, j) * H(j, n); + } + w = H(i, i) - p; + + if (e_[i] < 0.0) { + z = w; + r = ra; + s = sa; + } else { + l = i; + if (e_[i] == 0) { + Real cdivr, cdivi; + cdiv(-ra, -sa, w, q, &cdivr, &cdivi); + H(i, n-1) = cdivr; + H(i, n) = cdivi; + } else { + Real cdivr, cdivi; + // Solve complex equations + + x = H(i, i+1); + y = H(i+1, i); + vr = (d_[i] - p) * (d_[i] - p) +e_[i] *e_[i] - q * q; + vi = (d_[i] - p) * 2.0 
* q; + if (vr == 0.0 && vi == 0.0) { + vr = eps * norm * (std::abs(w) + std::abs(q) + + std::abs(x) + std::abs(y) + std::abs(z)); + } + cdiv(x*r-z*ra+q*sa, x*s-z*sa-q*ra, vr, vi, &cdivr, &cdivi); + H(i, n-1) = cdivr; + H(i, n) = cdivi; + if (std::abs(x) > (std::abs(z) + std::abs(q))) { + H(i+1, n-1) = (-ra - w * H(i, n-1) + q * H(i, n)) / x; + H(i+1, n) = (-sa - w * H(i, n) - q * H(i, n-1)) / x; + } else { + cdiv(-r-y*H(i, n-1), -s-y*H(i, n), z, q, &cdivr, &cdivi); + H(i+1, n-1) = cdivr; + H(i+1, n) = cdivi; + } + } + + // Overflow control + + t = std::max(std::abs(H(i, n-1)), std::abs(H(i, n))); + if ((eps * t) * t > 1) { + for (int j = i; j <= n; j++) { + H(j, n-1) = H(j, n-1) / t; + H(j, n) = H(j, n) / t; + } + } + } + } + } + } + + // Vectors of isolated roots + + for (int i = 0; i < nn; i++) { + if (i < low || i > high) { + for (int j = i; j < nn; j++) { + V(i, j) = H(i, j); + } + } + } + + // Back transformation to get eigenvectors of original matrix + + for (int j = nn-1; j >= low; j--) { + for (int i = low; i <= high; i++) { + z = 0.0; + for (int k = low; k <= std::min(j, high); k++) { + z = z + V(i, k) * H(k, j); + } + V(i, j) = z; + } + } +} + +template +EigenvalueDecomposition::EigenvalueDecomposition(const MatrixBase &A) { + KALDI_ASSERT(A.NumCols() == A.NumRows() && A.NumCols() >= 1); + n_ = A.NumRows(); + V_ = new Real[n_*n_]; + d_ = new Real[n_]; + e_ = new Real[n_]; + H_ = NULL; + ort_ = NULL; + if (A.IsSymmetric(0.0)) { + + for (int i = 0; i < n_; i++) + for (int j = 0; j < n_; j++) + V(i, j) = A(i, j); // Note that V(i, j) is a member function; A(i, j) is an operator + // of the matrix A. + // Tridiagonalize. + Tred2(); + + // Diagonalize. + Tql2(); + } else { + H_ = new Real[n_*n_]; + ort_ = new Real[n_]; + for (int i = 0; i < n_; i++) + for (int j = 0; j < n_; j++) + H(i, j) = A(i, j); // as before: H is member function, A(i, j) is operator of matrix. + + // Reduce to Hessenberg form. + Orthes(); + + // Reduce Hessenberg to real Schur form. 
+ Hqr2(); + } +} + +template +EigenvalueDecomposition::~EigenvalueDecomposition() { + delete [] d_; + delete [] e_; + delete [] V_; + delete [] H_; + delete [] ort_; +} + +// see function MatrixBase::Eig in kaldi-matrix.cc + + +} // namespace kaldi + +#endif // KALDI_MATRIX_JAMA_EIG_H_ diff --git a/speechx/speechx/kaldi/matrix/jama-svd.h b/speechx/speechx/kaldi/matrix/jama-svd.h new file mode 100644 index 00000000..8304dac6 --- /dev/null +++ b/speechx/speechx/kaldi/matrix/jama-svd.h @@ -0,0 +1,531 @@ +// matrix/jama-svd.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// This file consists of a port and modification of materials from +// JAMA: A Java Matrix Package +// under the following notice: This software is a cooperative product of +// The MathWorks and the National Institute of Standards and Technology (NIST) +// which has been released to the public. 
This notice and the original code are +// available at http://math.nist.gov/javanumerics/jama/domain.notice + + +#ifndef KALDI_MATRIX_JAMA_SVD_H_ +#define KALDI_MATRIX_JAMA_SVD_H_ 1 + + +#include "matrix/kaldi-matrix.h" +#include "matrix/sp-matrix.h" +#include "matrix/cblas-wrappers.h" + +namespace kaldi { + +#if defined(HAVE_ATLAS) || defined(USE_KALDI_SVD) +// using ATLAS as our math library, which doesn't have SVD -> need +// to implement it. + +// This routine is a modified form of jama_svd.h which is part of the TNT distribution. +// (originally comes from JAMA). + +/** Singular Value Decomposition. + *

+ * For an m-by-n matrix A with m >= n, the singular value decomposition is + * an m-by-n orthogonal matrix U, an n-by-n diagonal matrix S, and + * an n-by-n orthogonal matrix V so that A = U*S*V'. + *

+ * The singular values, sigma[k] = S(k, k), are ordered so that + * sigma[0] >= sigma[1] >= ... >= sigma[n-1]. + *

+ * The singular value decompostion always exists, so the constructor will + * never fail. The matrix condition number and the effective numerical + * rank can be computed from this decomposition. + + *

+ * (Adapted from JAMA, a Java Matrix Library, developed by jointly + * by the Mathworks and NIST; see http://math.nist.gov/javanumerics/jama). + */ + + +template +bool MatrixBase::JamaSvd(VectorBase *s_in, + MatrixBase *U_in, + MatrixBase *V_in) { // Destructive! + KALDI_ASSERT(s_in != NULL && U_in != this && V_in != this); + int wantu = (U_in != NULL), wantv = (V_in != NULL); + Matrix Utmp, Vtmp; + MatrixBase &U = (U_in ? *U_in : Utmp), &V = (V_in ? *V_in : Vtmp); + VectorBase &s = *s_in; + + int m = num_rows_, n = num_cols_; + KALDI_ASSERT(m>=n && m != 0 && n != 0); + if (wantu) KALDI_ASSERT((int)U.num_rows_ == m && (int)U.num_cols_ == n); + if (wantv) KALDI_ASSERT((int)V.num_rows_ == n && (int)V.num_cols_ == n); + KALDI_ASSERT((int)s.Dim() == n); // n<=m so n is min. + + int nu = n; + U.SetZero(); // make sure all zero. + Vector e(n); + Vector work(m); + MatrixBase &A(*this); + Real *adata = A.Data(), *workdata = work.Data(), *edata = e.Data(), + *udata = U.Data(), *vdata = V.Data(); + int astride = static_cast(A.Stride()), + ustride = static_cast(U.Stride()), + vstride = static_cast(V.Stride()); + int i = 0, j = 0, k = 0; + + // Reduce A to bidiagonal form, storing the diagonal elements + // in s and the super-diagonal elements in e. + + int nct = std::min(m-1, n); + int nrt = std::max(0, std::min(n-2, m)); + for (k = 0; k < std::max(nct, nrt); k++) { + if (k < nct) { + + // Compute the transformation for the k-th column and + // place the k-th diagonal in s(k). + // Compute 2-norm of k-th column without under/overflow. + s(k) = 0; + for (i = k; i < m; i++) { + s(k) = hypot(s(k), A(i, k)); + } + if (s(k) != 0.0) { + if (A(k, k) < 0.0) { + s(k) = -s(k); + } + for (i = k; i < m; i++) { + A(i, k) /= s(k); + } + A(k, k) += 1.0; + } + s(k) = -s(k); + } + for (j = k+1; j < n; j++) { + if ((k < nct) && (s(k) != 0.0)) { + + // Apply the transformation. 
+ + Real t = cblas_Xdot(m - k, adata + astride*k + k, astride, + adata + astride*k + j, astride); + /*for (i = k; i < m; i++) { + t += adata[i*astride + k]*adata[i*astride + j]; // A(i, k)*A(i, j); // 3 + }*/ + t = -t/A(k, k); + cblas_Xaxpy(m - k, t, adata + k*astride + k, astride, + adata + k*astride + j, astride); + /*for (i = k; i < m; i++) { + adata[i*astride + j] += t*adata[i*astride + k]; // A(i, j) += t*A(i, k); // 5 + }*/ + } + + // Place the k-th row of A into e for the + // subsequent calculation of the row transformation. + + e(j) = A(k, j); + } + if (wantu & (k < nct)) { + + // Place the transformation in U for subsequent back + // multiplication. + + for (i = k; i < m; i++) { + U(i, k) = A(i, k); + } + } + if (k < nrt) { + + // Compute the k-th row transformation and place the + // k-th super-diagonal in e(k). + // Compute 2-norm without under/overflow. + e(k) = 0; + for (i = k+1; i < n; i++) { + e(k) = hypot(e(k), e(i)); + } + if (e(k) != 0.0) { + if (e(k+1) < 0.0) { + e(k) = -e(k); + } + for (i = k+1; i < n; i++) { + e(i) /= e(k); + } + e(k+1) += 1.0; + } + e(k) = -e(k); + if ((k+1 < m) & (e(k) != 0.0)) { + + // Apply the transformation. + + for (i = k+1; i < m; i++) { + work(i) = 0.0; + } + for (j = k+1; j < n; j++) { + for (i = k+1; i < m; i++) { + workdata[i] += edata[j] * adata[i*astride + j]; // work(i) += e(j)*A(i, j); // 5 + } + } + for (j = k+1; j < n; j++) { + Real t(-e(j)/e(k+1)); + cblas_Xaxpy(m - (k+1), t, workdata + (k+1), 1, + adata + (k+1)*astride + j, astride); + /* + for (i = k+1; i < m; i++) { + adata[i*astride + j] += t*workdata[i]; // A(i, j) += t*work(i); // 5 + }*/ + } + } + if (wantv) { + + // Place the transformation in V for subsequent + // back multiplication. + + for (i = k+1; i < n; i++) { + V(i, k) = e(i); + } + } + } + } + + // Set up the final bidiagonal matrix or order p. 
+ + int p = std::min(n, m+1); + if (nct < n) { + s(nct) = A(nct, nct); + } + if (m < p) { + s(p-1) = 0.0; + } + if (nrt+1 < p) { + e(nrt) = A(nrt, p-1); + } + e(p-1) = 0.0; + + // If required, generate U. + + if (wantu) { + for (j = nct; j < nu; j++) { + for (i = 0; i < m; i++) { + U(i, j) = 0.0; + } + U(j, j) = 1.0; + } + for (k = nct-1; k >= 0; k--) { + if (s(k) != 0.0) { + for (j = k+1; j < nu; j++) { + Real t = cblas_Xdot(m - k, udata + k*ustride + k, ustride, udata + k*ustride + j, ustride); + //for (i = k; i < m; i++) { + // t += udata[i*ustride + k]*udata[i*ustride + j]; // t += U(i, k)*U(i, j); // 8 + // } + t = -t/U(k, k); + cblas_Xaxpy(m - k, t, udata + ustride*k + k, ustride, + udata + k*ustride + j, ustride); + /*for (i = k; i < m; i++) { + udata[i*ustride + j] += t*udata[i*ustride + k]; // U(i, j) += t*U(i, k); // 4 + }*/ + } + for (i = k; i < m; i++ ) { + U(i, k) = -U(i, k); + } + U(k, k) = 1.0 + U(k, k); + for (i = 0; i < k-1; i++) { + U(i, k) = 0.0; + } + } else { + for (i = 0; i < m; i++) { + U(i, k) = 0.0; + } + U(k, k) = 1.0; + } + } + } + + // If required, generate V. + + if (wantv) { + for (k = n-1; k >= 0; k--) { + if ((k < nrt) & (e(k) != 0.0)) { + for (j = k+1; j < nu; j++) { + Real t = cblas_Xdot(n - (k+1), vdata + (k+1)*vstride + k, vstride, + vdata + (k+1)*vstride + j, vstride); + /*Real t (0.0); + for (i = k+1; i < n; i++) { + t += vdata[i*vstride + k]*vdata[i*vstride + j]; // t += V(i, k)*V(i, j); // 7 + }*/ + t = -t/V(k+1, k); + cblas_Xaxpy(n - (k+1), t, vdata + (k+1)*vstride + k, vstride, + vdata + (k+1)*vstride + j, vstride); + /*for (i = k+1; i < n; i++) { + vdata[i*vstride + j] += t*vdata[i*vstride + k]; // V(i, j) += t*V(i, k); // 7 + }*/ + } + } + for (i = 0; i < n; i++) { + V(i, k) = 0.0; + } + V(k, k) = 1.0; + } + } + + // Main iteration loop for the singular values. 
+ + int pp = p-1; + int iter = 0; + // note: -52.0 is from Jama code; the -23 is the extension + // to float, because mantissa length in (double, float) + // is (52, 23) bits respectively. + Real eps(pow(2.0, sizeof(Real) == 4 ? -23.0 : -52.0)); + // Note: the -966 was taken from Jama code, but the -120 is a guess + // of how to extend this to float... the exponent in double goes + // from -1022 .. 1023, and in float from -126..127. I'm not sure + // what the significance of 966 is, so -120 just represents a number + // that's a bit less negative than -126. If we get convergence + // failure in float only, this may mean that we have to make the + // -120 value less negative. + Real tiny(pow(2.0, sizeof(Real) == 4 ? -120.0: -966.0 )); + + while (p > 0) { + int k = 0; + int kase = 0; + + if (iter == 500 || iter == 750) { + KALDI_WARN << "Svd taking a long time: making convergence criterion less exact."; + eps = pow(static_cast(0.8), eps); + tiny = pow(static_cast(0.8), tiny); + } + if (iter > 1000) { + KALDI_WARN << "Svd not converging on matrix of size " << m << " by " <= -1; k--) { + if (k == -1) { + break; + } + if (std::abs(e(k)) <= + tiny + eps*(std::abs(s(k)) + std::abs(s(k+1)))) { + e(k) = 0.0; + break; + } + } + if (k == p-2) { + kase = 4; + } else { + int ks; + for (ks = p-1; ks >= k; ks--) { + if (ks == k) { + break; + } + Real t( (ks != p ? std::abs(e(ks)) : 0.) + + (ks != k+1 ? std::abs(e(ks-1)) : 0.)); + if (std::abs(s(ks)) <= tiny + eps*t) { + s(ks) = 0.0; + break; + } + } + if (ks == k) { + kase = 3; + } else if (ks == p-1) { + kase = 1; + } else { + kase = 2; + k = ks; + } + } + k++; + + // Perform the task indicated by kase. + + switch (kase) { + + // Deflate negligible s(p). 
+ + case 1: { + Real f(e(p-2)); + e(p-2) = 0.0; + for (j = p-2; j >= k; j--) { + Real t( hypot(s(j), f)); + Real cs(s(j)/t); + Real sn(f/t); + s(j) = t; + if (j != k) { + f = -sn*e(j-1); + e(j-1) = cs*e(j-1); + } + if (wantv) { + for (i = 0; i < n; i++) { + t = cs*V(i, j) + sn*V(i, p-1); + V(i, p-1) = -sn*V(i, j) + cs*V(i, p-1); + V(i, j) = t; + } + } + } + } + break; + + // Split at negligible s(k). + + case 2: { + Real f(e(k-1)); + e(k-1) = 0.0; + for (j = k; j < p; j++) { + Real t(hypot(s(j), f)); + Real cs( s(j)/t); + Real sn(f/t); + s(j) = t; + f = -sn*e(j); + e(j) = cs*e(j); + if (wantu) { + for (i = 0; i < m; i++) { + t = cs*U(i, j) + sn*U(i, k-1); + U(i, k-1) = -sn*U(i, j) + cs*U(i, k-1); + U(i, j) = t; + } + } + } + } + break; + + // Perform one qr step. + + case 3: { + + // Calculate the shift. + + Real scale = std::max(std::max(std::max(std::max( + std::abs(s(p-1)), std::abs(s(p-2))), std::abs(e(p-2))), + std::abs(s(k))), std::abs(e(k))); + Real sp = s(p-1)/scale; + Real spm1 = s(p-2)/scale; + Real epm1 = e(p-2)/scale; + Real sk = s(k)/scale; + Real ek = e(k)/scale; + Real b = ((spm1 + sp)*(spm1 - sp) + epm1*epm1)/2.0; + Real c = (sp*epm1)*(sp*epm1); + Real shift = 0.0; + if ((b != 0.0) || (c != 0.0)) { + shift = std::sqrt(b*b + c); + if (b < 0.0) { + shift = -shift; + } + shift = c/(b + shift); + } + Real f = (sk + sp)*(sk - sp) + shift; + Real g = sk*ek; + + // Chase zeros. 
+ + for (j = k; j < p-1; j++) { + Real t = hypot(f, g); + Real cs = f/t; + Real sn = g/t; + if (j != k) { + e(j-1) = t; + } + f = cs*s(j) + sn*e(j); + e(j) = cs*e(j) - sn*s(j); + g = sn*s(j+1); + s(j+1) = cs*s(j+1); + if (wantv) { + cblas_Xrot(n, vdata + j, vstride, vdata + j+1, vstride, cs, sn); + /*for (i = 0; i < n; i++) { + t = cs*vdata[i*vstride + j] + sn*vdata[i*vstride + j+1]; // t = cs*V(i, j) + sn*V(i, j+1); // 13 + vdata[i*vstride + j+1] = -sn*vdata[i*vstride + j] + cs*vdata[i*vstride + j+1]; // V(i, j+1) = -sn*V(i, j) + cs*V(i, j+1); // 5 + vdata[i*vstride + j] = t; // V(i, j) = t; // 4 + }*/ + } + t = hypot(f, g); + cs = f/t; + sn = g/t; + s(j) = t; + f = cs*e(j) + sn*s(j+1); + s(j+1) = -sn*e(j) + cs*s(j+1); + g = sn*e(j+1); + e(j+1) = cs*e(j+1); + if (wantu && (j < m-1)) { + cblas_Xrot(m, udata + j, ustride, udata + j+1, ustride, cs, sn); + /*for (i = 0; i < m; i++) { + t = cs*udata[i*ustride + j] + sn*udata[i*ustride + j+1]; // t = cs*U(i, j) + sn*U(i, j+1); // 7 + udata[i*ustride + j+1] = -sn*udata[i*ustride + j] +cs*udata[i*ustride + j+1]; // U(i, j+1) = -sn*U(i, j) + cs*U(i, j+1); // 8 + udata[i*ustride + j] = t; // U(i, j) = t; // 1 + }*/ + } + } + e(p-2) = f; + iter = iter + 1; + } + break; + + // Convergence. + + case 4: { + + // Make the singular values positive. + + if (s(k) <= 0.0) { + s(k) = (s(k) < 0.0 ? -s(k) : 0.0); + if (wantv) { + for (i = 0; i <= pp; i++) { + V(i, k) = -V(i, k); + } + } + } + + // Order the singular values. 
+ + while (k < pp) { + if (s(k) >= s(k+1)) { + break; + } + Real t = s(k); + s(k) = s(k+1); + s(k+1) = t; + if (wantv && (k < n-1)) { + for (i = 0; i < n; i++) { + t = V(i, k+1); V(i, k+1) = V(i, k); V(i, k) = t; + } + } + if (wantu && (k < m-1)) { + for (i = 0; i < m; i++) { + t = U(i, k+1); U(i, k+1) = U(i, k); U(i, k) = t; + } + } + k++; + } + iter = 0; + p--; + } + break; + } + } + return true; +} + +#endif // defined(HAVE_ATLAS) || defined(USE_KALDI_SVD) + +} // namespace kaldi + +#endif // KALDI_MATRIX_JAMA_SVD_H_ diff --git a/speechx/speechx/kaldi/matrix/kaldi-blas.h b/speechx/speechx/kaldi/matrix/kaldi-blas.h new file mode 100644 index 00000000..b08d8c51 --- /dev/null +++ b/speechx/speechx/kaldi/matrix/kaldi-blas.h @@ -0,0 +1,133 @@ +// matrix/kaldi-blas.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_MATRIX_KALDI_BLAS_H_ +#define KALDI_MATRIX_KALDI_BLAS_H_ + +// This file handles the #includes for BLAS, LAPACK and so on. +// It manipulates the declarations into a common format that kaldi can handle. +// However, the kaldi code will check whether HAVE_ATLAS is defined as that +// code is called a bit differently from CLAPACK that comes from other sources. 
+ +// There are three alternatives: +// (i) you have ATLAS, which includes the ATLAS implementation of CBLAS +// plus a subset of CLAPACK (but with clapack_ in the function declarations). +// In this case, define HAVE_ATLAS and make sure the relevant directories are +// in the include path. + +// (ii) you have CBLAS (some implementation thereof) plus CLAPACK. +// In this case, define HAVE_CLAPACK. +// [Since CLAPACK depends on BLAS, the presence of BLAS is implicit]. + +// (iii) you have the MKL library, which includes CLAPACK and CBLAS. + +// Note that if we are using ATLAS, no Svd implementation is supplied, +// so we define HAVE_Svd to be zero and this directs our implementation to +// supply its own "by hand" implementation which is based on TNT code. + + + +#define HAVE_OPENBLAS + +#if (defined(HAVE_CLAPACK) && (defined(HAVE_ATLAS) || defined(HAVE_MKL))) \ + || (defined(HAVE_ATLAS) && defined(HAVE_MKL)) +#error "Do not define more than one of HAVE_CLAPACK, HAVE_ATLAS and HAVE_MKL" +#endif + +#ifdef HAVE_ATLAS + extern "C" { + #include "cblas.h" + #include "clapack.h" + } +#elif defined(HAVE_CLAPACK) + #ifdef __APPLE__ + #ifndef __has_extension + #define __has_extension(x) 0 + #endif + #define vImage_Utilities_h + #define vImage_CVUtilities_h + #include + typedef __CLPK_integer integer; + typedef __CLPK_logical logical; + typedef __CLPK_real real; + typedef __CLPK_doublereal doublereal; + typedef __CLPK_complex complex; + typedef __CLPK_doublecomplex doublecomplex; + typedef __CLPK_ftnlen ftnlen; + #else + extern "C" { + // May be in /usr/[local]/include if installed; else this uses the one + // from the tools/CLAPACK_include directory. + #include + #include + #include + + // get rid of macros from f2c.h -- these are dangerous. 
+ #undef abs + #undef dabs + #undef min + #undef max + #undef dmin + #undef dmax + #undef bit_test + #undef bit_clear + #undef bit_set + } + #endif +#elif defined(HAVE_MKL) + extern "C" { + #include + } +#elif defined(HAVE_OPENBLAS) + // getting cblas.h and lapacke.h from /. + // putting in "" not <> to search -I before system libraries. + #include "third_party/openblas/cblas.h" + #include "third_party/openblas/lapacke.h" + #undef I + #undef complex + // get rid of macros from f2c.h -- these are dangerous. + #undef abs + #undef dabs + #undef min + #undef max + #undef dmin + #undef dmax + #undef bit_test + #undef bit_clear + #undef bit_set +#else + #error "You need to define (using the preprocessor) either HAVE_CLAPACK or HAVE_ATLAS or HAVE_MKL (but not more than one)" +#endif + +#ifdef HAVE_OPENBLAS +typedef int KaldiBlasInt; // try int. +#endif +#ifdef HAVE_CLAPACK +typedef integer KaldiBlasInt; +#endif +#ifdef HAVE_MKL +typedef MKL_INT KaldiBlasInt; +#endif + +#ifdef HAVE_ATLAS +// in this case there is no need for KaldiBlasInt-- this typedef is only needed +// for Svd code which is not included in ATLAS (we re-implement it). +#endif + + +#endif // KALDI_MATRIX_KALDI_BLAS_H_ diff --git a/speechx/speechx/kaldi/matrix/kaldi-matrix-inl.h b/speechx/speechx/kaldi/matrix/kaldi-matrix-inl.h new file mode 100644 index 00000000..c2ff0079 --- /dev/null +++ b/speechx/speechx/kaldi/matrix/kaldi-matrix-inl.h @@ -0,0 +1,63 @@ +// matrix/kaldi-matrix-inl.h + +// Copyright 2009-2011 Microsoft Corporation; Haihua Xu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_KALDI_MATRIX_INL_H_ +#define KALDI_MATRIX_KALDI_MATRIX_INL_H_ 1 + +#include "matrix/kaldi-vector.h" + +namespace kaldi { + +/// Empty constructor +template +Matrix::Matrix(): MatrixBase(NULL, 0, 0, 0) { } + + +template<> +template<> +void MatrixBase::AddVecVec(const float alpha, const VectorBase &ra, const VectorBase &rb); + +template<> +template<> +void MatrixBase::AddVecVec(const double alpha, const VectorBase &ra, const VectorBase &rb); + +template +inline std::ostream & operator << (std::ostream & os, const MatrixBase & M) { + M.Write(os, false); + return os; +} + +template +inline std::istream & operator >> (std::istream & is, Matrix & M) { + M.Read(is, false); + return is; +} + + +template +inline std::istream & operator >> (std::istream & is, MatrixBase & M) { + M.Read(is, false); + return is; +} + +}// namespace kaldi + + +#endif // KALDI_MATRIX_KALDI_MATRIX_INL_H_ + diff --git a/speechx/speechx/kaldi/matrix/kaldi-matrix.cc b/speechx/speechx/kaldi/matrix/kaldi-matrix.cc new file mode 100644 index 00000000..faf23cdf --- /dev/null +++ b/speechx/speechx/kaldi/matrix/kaldi-matrix.cc @@ -0,0 +1,3103 @@ +// matrix/kaldi-matrix.cc + +// Copyright 2009-2011 Lukas Burget; Ondrej Glembek; Go Vivace Inc.; +// Microsoft Corporation; Saarland University; +// Yanmin Qian; Petr Schwarz; Jan Silovsky; +// Haihua Xu +// 2017 Shiyin Kang +// 2019 Yiwen Shao + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, 
Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "matrix/kaldi-matrix.h" +#include "matrix/sp-matrix.h" +#include "matrix/jama-svd.h" +#include "matrix/jama-eig.h" +#include "matrix/compressed-matrix.h" +#include "matrix/sparse-matrix.h" + +static_assert(int(kaldi::kNoTrans) == int(CblasNoTrans) && int(kaldi::kTrans) == int(CblasTrans), + "kaldi::kNoTrans and kaldi::kTrans must be equal to the appropriate CBLAS library constants!"); + +namespace kaldi { + +template +void MatrixBase::Invert(Real *log_det, Real *det_sign, + bool inverse_needed) { + KALDI_ASSERT(num_rows_ == num_cols_); + if (num_rows_ == 0) { + if (det_sign) *det_sign = 1; + if (log_det) *log_det = 0.0; + return; + } +#ifndef HAVE_ATLAS + KaldiBlasInt *pivot = new KaldiBlasInt[num_rows_]; + KaldiBlasInt M = num_rows_; + KaldiBlasInt N = num_cols_; + KaldiBlasInt LDA = stride_; + KaldiBlasInt result = -1; + KaldiBlasInt l_work = std::max(1, N); + Real *p_work; + void *temp; + if ((p_work = static_cast( + KALDI_MEMALIGN(16, sizeof(Real)*l_work, &temp))) == NULL) { + delete[] pivot; + throw std::bad_alloc(); + } + + clapack_Xgetrf2(&M, &N, data_, &LDA, pivot, &result); + const int pivot_offset = 1; +#else + int *pivot = new int[num_rows_]; + int result; + clapack_Xgetrf(num_rows_, num_cols_, data_, stride_, pivot, &result); + const int pivot_offset = 0; +#endif + KALDI_ASSERT(result >= 0 && "Call to CLAPACK sgetrf_ or ATLAS clapack_sgetrf " + "called with wrong 
arguments"); + if (result > 0) { + if (inverse_needed) { + KALDI_ERR << "Cannot invert: matrix is singular"; + } else { + if (log_det) *log_det = -std::numeric_limits::infinity(); + if (det_sign) *det_sign = 0; + delete[] pivot; +#ifndef HAVE_ATLAS + KALDI_MEMALIGN_FREE(p_work); +#endif + return; + } + } + if (det_sign != NULL) { + int sign = 1; + for (MatrixIndexT i = 0; i < num_rows_; i++) + if (pivot[i] != static_cast(i) + pivot_offset) sign *= -1; + *det_sign = sign; + } + if (log_det != NULL || det_sign != NULL) { // Compute log determinant. + if (log_det != NULL) *log_det = 0.0; + Real prod = 1.0; + for (MatrixIndexT i = 0; i < num_rows_; i++) { + prod *= (*this)(i, i); + if (i == num_rows_ - 1 || std::fabs(prod) < 1.0e-10 || + std::fabs(prod) > 1.0e+10) { + if (log_det != NULL) *log_det += kaldi::Log(std::fabs(prod)); + if (det_sign != NULL) *det_sign *= (prod > 0 ? 1.0 : -1.0); + prod = 1.0; + } + } + } +#ifndef HAVE_ATLAS + if (inverse_needed) clapack_Xgetri2(&M, data_, &LDA, pivot, p_work, &l_work, + &result); + delete[] pivot; + KALDI_MEMALIGN_FREE(p_work); +#else + if (inverse_needed) + clapack_Xgetri(num_rows_, data_, stride_, pivot, &result); + delete [] pivot; +#endif + KALDI_ASSERT(result == 0 && "Call to CLAPACK sgetri_ or ATLAS clapack_sgetri " + "called with wrong arguments"); +} + +template<> +template<> +void MatrixBase::AddVecVec(const float alpha, + const VectorBase &a, + const VectorBase &rb) { + KALDI_ASSERT(a.Dim() == num_rows_ && rb.Dim() == num_cols_); + cblas_Xger(a.Dim(), rb.Dim(), alpha, a.Data(), 1, rb.Data(), + 1, data_, stride_); +} + +template +template +void MatrixBase::AddVecVec(const Real alpha, + const VectorBase &a, + const VectorBase &b) { + KALDI_ASSERT(a.Dim() == num_rows_ && b.Dim() == num_cols_); + if (num_rows_ * num_cols_ > 100) { // It's probably worth it to allocate + // temporary vectors of the right type and use BLAS. 
+ Vector temp_a(a), temp_b(b); + cblas_Xger(num_rows_, num_cols_, alpha, temp_a.Data(), 1, + temp_b.Data(), 1, data_, stride_); + } else { + const OtherReal *a_data = a.Data(), *b_data = b.Data(); + Real *row_data = data_; + for (MatrixIndexT i = 0; i < num_rows_; i++, row_data += stride_) { + BaseFloat alpha_ai = alpha * a_data[i]; + for (MatrixIndexT j = 0; j < num_cols_; j++) + row_data[j] += alpha_ai * b_data[j]; + } + } +} + +// instantiate the template above. +template +void MatrixBase::AddVecVec(const float alpha, + const VectorBase &a, + const VectorBase &b); +template +void MatrixBase::AddVecVec(const double alpha, + const VectorBase &a, + const VectorBase &b); + +template<> +template<> +void MatrixBase::AddVecVec(const double alpha, + const VectorBase &a, + const VectorBase &rb) { + KALDI_ASSERT(a.Dim() == num_rows_ && rb.Dim() == num_cols_); + if (num_rows_ == 0) return; + cblas_Xger(a.Dim(), rb.Dim(), alpha, a.Data(), 1, rb.Data(), + 1, data_, stride_); +} + +template +void MatrixBase::AddMatMat(const Real alpha, + const MatrixBase& A, + MatrixTransposeType transA, + const MatrixBase& B, + MatrixTransposeType transB, + const Real beta) { + KALDI_ASSERT((transA == kNoTrans && transB == kNoTrans && A.num_cols_ == B.num_rows_ && A.num_rows_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kTrans && transB == kNoTrans && A.num_rows_ == B.num_rows_ && A.num_cols_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kNoTrans && transB == kTrans && A.num_cols_ == B.num_cols_ && A.num_rows_ == num_rows_ && B.num_rows_ == num_cols_) + || (transA == kTrans && transB == kTrans && A.num_rows_ == B.num_cols_ && A.num_cols_ == num_rows_ && B.num_rows_ == num_cols_)); + KALDI_ASSERT(&A != this && &B != this); + if (num_rows_ == 0) return; + cblas_Xgemm(alpha, transA, A.data_, A.num_rows_, A.num_cols_, A.stride_, + transB, B.data_, B.stride_, beta, data_, num_rows_, num_cols_, stride_); + +} + +template +void MatrixBase::SetMatMatDivMat(const MatrixBase& 
A, + const MatrixBase& B, + const MatrixBase& C) { + KALDI_ASSERT(A.NumRows() == B.NumRows() && A.NumCols() == B.NumCols()); + KALDI_ASSERT(A.NumRows() == C.NumRows() && A.NumCols() == C.NumCols()); + for (int32 r = 0; r < A.NumRows(); r++) { // each frame... + for (int32 c = 0; c < A.NumCols(); c++) { + BaseFloat i = C(r, c), o = B(r, c), od = A(r, c), + id; + if (i != 0.0) { + id = od * (o / i); /// o / i is either zero or "scale". + } else { + id = od; /// Just imagine the scale was 1.0. This is somehow true in + /// expectation; anyway, this case should basically never happen so it doesn't + /// really matter. + } + (*this)(r, c) = id; + } + } +} + + +template +void MatrixBase::CopyLowerToUpper() { + KALDI_ASSERT(num_rows_ == num_cols_); + Real *data = data_; + MatrixIndexT num_rows = num_rows_, stride = stride_; + for (int32 i = 0; i < num_rows; i++) + for (int32 j = 0; j < i; j++) + data[j * stride + i ] = data[i * stride + j]; +} + + +template +void MatrixBase::CopyUpperToLower() { + KALDI_ASSERT(num_rows_ == num_cols_); + Real *data = data_; + MatrixIndexT num_rows = num_rows_, stride = stride_; + for (int32 i = 0; i < num_rows; i++) + for (int32 j = 0; j < i; j++) + data[i * stride + j] = data[j * stride + i]; +} + +template +void MatrixBase::SymAddMat2(const Real alpha, + const MatrixBase &A, + MatrixTransposeType transA, + Real beta) { + KALDI_ASSERT(num_rows_ == num_cols_ && + ((transA == kNoTrans && A.num_rows_ == num_rows_) || + (transA == kTrans && A.num_cols_ == num_cols_))); + KALDI_ASSERT(A.data_ != data_); + if (num_rows_ == 0) return; + + /// When the matrix dimension(this->num_rows_) is not less than 56 + /// and the transpose type transA == kTrans, the cblas_Xsyrk(...) + /// function will produce NaN in the output. This is a bug in the + /// ATLAS library. To overcome this, the AddMatMat function, which calls + /// cblas_Xgemm(...) rather than cblas_Xsyrk(...), is used in this special + /// sitation. 
+ /// Wei Shi: Note this bug is observerd for single precision matrix + /// on a 64-bit machine +#ifdef HAVE_ATLAS + if (transA == kTrans && num_rows_ >= 56) { + this->AddMatMat(alpha, A, kTrans, A, kNoTrans, beta); + return; + } +#endif // HAVE_ATLAS + + MatrixIndexT A_other_dim = (transA == kNoTrans ? A.num_cols_ : A.num_rows_); + + // This function call is hard-coded to update the lower triangle. + cblas_Xsyrk(transA, num_rows_, A_other_dim, alpha, A.Data(), + A.Stride(), beta, this->data_, this->stride_); +} + + +template +void MatrixBase::AddMatSmat(const Real alpha, + const MatrixBase &A, + MatrixTransposeType transA, + const MatrixBase &B, + MatrixTransposeType transB, + const Real beta) { + KALDI_ASSERT((transA == kNoTrans && transB == kNoTrans && A.num_cols_ == B.num_rows_ && A.num_rows_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kTrans && transB == kNoTrans && A.num_rows_ == B.num_rows_ && A.num_cols_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kNoTrans && transB == kTrans && A.num_cols_ == B.num_cols_ && A.num_rows_ == num_rows_ && B.num_rows_ == num_cols_) + || (transA == kTrans && transB == kTrans && A.num_rows_ == B.num_cols_ && A.num_cols_ == num_rows_ && B.num_rows_ == num_cols_)); + KALDI_ASSERT(&A != this && &B != this); + + // We iterate over the columns of B. + + MatrixIndexT Astride = A.stride_, Bstride = B.stride_, stride = this->stride_, + Arows = A.num_rows_, Acols = A.num_cols_; + Real *data = this->data_, *Adata = A.data_, *Bdata = B.data_; + MatrixIndexT num_cols = this->num_cols_; + if (transB == kNoTrans) { + // Iterate over the columns of *this and of B. + for (MatrixIndexT c = 0; c < num_cols; c++) { + // for each column of *this, do + // [this column] = [alpha * A * this column of B] + [beta * this column] + Xgemv_sparsevec(transA, Arows, Acols, alpha, Adata, Astride, + Bdata + c, Bstride, beta, data + c, stride); + } + } else { + // Iterate over the columns of *this and the rows of B. 
+ for (MatrixIndexT c = 0; c < num_cols; c++) { + // for each column of *this, do + // [this column] = [alpha * A * this row of B] + [beta * this column] + Xgemv_sparsevec(transA, Arows, Acols, alpha, Adata, Astride, + Bdata + (c * Bstride), 1, beta, data + c, stride); + } + } +} + +template +void MatrixBase::AddSmatMat(const Real alpha, + const MatrixBase &A, + MatrixTransposeType transA, + const MatrixBase &B, + MatrixTransposeType transB, + const Real beta) { + KALDI_ASSERT((transA == kNoTrans && transB == kNoTrans && A.num_cols_ == B.num_rows_ && A.num_rows_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kTrans && transB == kNoTrans && A.num_rows_ == B.num_rows_ && A.num_cols_ == num_rows_ && B.num_cols_ == num_cols_) + || (transA == kNoTrans && transB == kTrans && A.num_cols_ == B.num_cols_ && A.num_rows_ == num_rows_ && B.num_rows_ == num_cols_) + || (transA == kTrans && transB == kTrans && A.num_rows_ == B.num_cols_ && A.num_cols_ == num_rows_ && B.num_rows_ == num_cols_)); + KALDI_ASSERT(&A != this && &B != this); + + MatrixIndexT Astride = A.stride_, Bstride = B.stride_, stride = this->stride_, + Brows = B.num_rows_, Bcols = B.num_cols_; + MatrixTransposeType invTransB = (transB == kTrans ? kNoTrans : kTrans); + Real *data = this->data_, *Adata = A.data_, *Bdata = B.data_; + MatrixIndexT num_rows = this->num_rows_; + if (transA == kNoTrans) { + // Iterate over the rows of *this and of A. + for (MatrixIndexT r = 0; r < num_rows; r++) { + // for each row of *this, do + // [this row] = [alpha * (this row of A) * B^T] + [beta * this row] + Xgemv_sparsevec(invTransB, Brows, Bcols, alpha, Bdata, Bstride, + Adata + (r * Astride), 1, beta, data + (r * stride), 1); + } + } else { + // Iterate over the rows of *this and the columns of A. 
+ for (MatrixIndexT r = 0; r < num_rows; r++) { + // for each row of *this, do + // [this row] = [alpha * (this column of A) * B^T] + [beta * this row] + Xgemv_sparsevec(invTransB, Brows, Bcols, alpha, Bdata, Bstride, + Adata + r, Astride, beta, data + (r * stride), 1); + } + } +} + +template +void MatrixBase::AddSpSp(const Real alpha, const SpMatrix &A_in, + const SpMatrix &B_in, const Real beta) { + MatrixIndexT sz = num_rows_; + KALDI_ASSERT(sz == num_cols_ && sz == A_in.NumRows() && sz == B_in.NumRows()); + + Matrix A(A_in), B(B_in); + // CblasLower or CblasUpper would work below as symmetric matrix is copied + // fully (to save work, we used the matrix constructor from SpMatrix). + // CblasLeft means A is on the left: C <-- alpha A B + beta C + if (sz == 0) return; + cblas_Xsymm(alpha, sz, A.data_, A.stride_, B.data_, B.stride_, beta, data_, stride_); +} + +template +void MatrixBase::AddMat(const Real alpha, const MatrixBase& A, + MatrixTransposeType transA) { + if (&A == this) { + if (transA == kNoTrans) { + Scale(alpha + 1.0); + } else { + KALDI_ASSERT(num_rows_ == num_cols_ && "AddMat: adding to self (transposed): not symmetric."); + Real *data = data_; + if (alpha == 1.0) { // common case-- handle separately. + for (MatrixIndexT row = 0; row < num_rows_; row++) { + for (MatrixIndexT col = 0; col < row; col++) { + Real *lower = data + (row * stride_) + col, *upper = data + (col + * stride_) + row; + Real sum = *lower + *upper; + *lower = *upper = sum; + } + *(data + (row * stride_) + row) *= 2.0; // diagonal. + } + } else { + for (MatrixIndexT row = 0; row < num_rows_; row++) { + for (MatrixIndexT col = 0; col < row; col++) { + Real *lower = data + (row * stride_) + col, *upper = data + (col + * stride_) + row; + Real lower_tmp = *lower; + *lower += alpha * *upper; + *upper += alpha * lower_tmp; + } + *(data + (row * stride_) + row) *= (1.0 + alpha); // diagonal. 
+ } + } + } + } else { + int aStride = (int) A.stride_, stride = stride_; + Real *adata = A.data_, *data = data_; + if (transA == kNoTrans) { + KALDI_ASSERT(A.num_rows_ == num_rows_ && A.num_cols_ == num_cols_); + if (num_rows_ == 0) return; + for (MatrixIndexT row = 0; row < num_rows_; row++, adata += aStride, + data += stride) { + cblas_Xaxpy(num_cols_, alpha, adata, 1, data, 1); + } + } else { + KALDI_ASSERT(A.num_cols_ == num_rows_ && A.num_rows_ == num_cols_); + if (num_rows_ == 0) return; + for (MatrixIndexT row = 0; row < num_rows_; row++, adata++, data += stride) + cblas_Xaxpy(num_cols_, alpha, adata, aStride, data, 1); + } + } +} + +template +void MatrixBase::AddSmat(Real alpha, const SparseMatrix &A, + MatrixTransposeType trans) { + if (trans == kNoTrans) { + KALDI_ASSERT(NumRows() == A.NumRows()); + KALDI_ASSERT(NumCols() == A.NumCols()); + MatrixIndexT a_num_rows = A.NumRows(); + for (MatrixIndexT i = 0; i < a_num_rows; ++i) { + const SparseVector &row = A.Row(i); + MatrixIndexT num_elems = row.NumElements(); + for (MatrixIndexT id = 0; id < num_elems; ++id) { + (*this)(i, row.GetElement(id).first) += alpha + * row.GetElement(id).second; + } + } + } else { + KALDI_ASSERT(NumRows() == A.NumCols()); + KALDI_ASSERT(NumCols() == A.NumRows()); + MatrixIndexT a_num_rows = A.NumRows(); + for (MatrixIndexT i = 0; i < a_num_rows; ++i) { + const SparseVector &row = A.Row(i); + MatrixIndexT num_elems = row.NumElements(); + for (MatrixIndexT id = 0; id < num_elems; ++id) { + (*this)(row.GetElement(id).first, i) += alpha + * row.GetElement(id).second; + } + } + } +} + +template +void MatrixBase::AddSmatMat(Real alpha, const SparseMatrix &A, + MatrixTransposeType transA, + const MatrixBase &B, Real beta) { + if (transA == kNoTrans) { + KALDI_ASSERT(NumRows() == A.NumRows()); + KALDI_ASSERT(NumCols() == B.NumCols()); + KALDI_ASSERT(A.NumCols() == B.NumRows()); + + this->Scale(beta); + MatrixIndexT a_num_rows = A.NumRows(), + this_num_cols = this->NumCols(); + for 
(MatrixIndexT i = 0; i < a_num_rows; ++i) { + Real *this_row_i = this->RowData(i); + const SparseVector &A_row_i = A.Row(i); + MatrixIndexT num_elems = A_row_i.NumElements(); + for (MatrixIndexT e = 0; e < num_elems; ++e) { + const std::pair &p = A_row_i.GetElement(e); + MatrixIndexT k = p.first; + Real alpha_A_ik = alpha * p.second; + const Real *b_row_k = B.RowData(k); + cblas_Xaxpy(this_num_cols, alpha_A_ik, b_row_k, 1, + this_row_i, 1); + //for (MatrixIndexT j = 0; j < this_num_cols; ++j) + // this_row_i[j] += alpha_A_ik * b_row_k[j]; + } + } + } else { + KALDI_ASSERT(NumRows() == A.NumCols()); + KALDI_ASSERT(NumCols() == B.NumCols()); + KALDI_ASSERT(A.NumRows() == B.NumRows()); + + this->Scale(beta); + Matrix buf(NumRows(), NumCols(), kSetZero); + MatrixIndexT a_num_rows = A.NumRows(), + this_num_cols = this->NumCols(); + for (int k = 0; k < a_num_rows; ++k) { + const Real *b_row_k = B.RowData(k); + const SparseVector &A_row_k = A.Row(k); + MatrixIndexT num_elems = A_row_k.NumElements(); + for (MatrixIndexT e = 0; e < num_elems; ++e) { + const std::pair &p = A_row_k.GetElement(e); + MatrixIndexT i = p.first; + Real alpha_A_ki = alpha * p.second; + Real *this_row_i = this->RowData(i); + cblas_Xaxpy(this_num_cols, alpha_A_ki, b_row_k, 1, + this_row_i, 1); + //for (MatrixIndexT j = 0; j < this_num_cols; ++j) + // this_row_i[j] += alpha_A_ki * b_row_k[j]; + } + } + } +} + +template +void MatrixBase::AddMatSmat(Real alpha, const MatrixBase &A, + const SparseMatrix &B, + MatrixTransposeType transB, Real beta) { + if (transB == kNoTrans) { + KALDI_ASSERT(NumRows() == A.NumRows()); + KALDI_ASSERT(NumCols() == B.NumCols()); + KALDI_ASSERT(A.NumCols() == B.NumRows()); + + this->Scale(beta); + MatrixIndexT b_num_rows = B.NumRows(), + this_num_rows = this->NumRows(); + // Iterate over the rows of sparse matrix B and columns of A. 
+ for (MatrixIndexT k = 0; k < b_num_rows; ++k) { + const SparseVector &B_row_k = B.Row(k); + MatrixIndexT num_elems = B_row_k.NumElements(); + const Real *a_col_k = A.Data() + k; + for (MatrixIndexT e = 0; e < num_elems; ++e) { + const std::pair &p = B_row_k.GetElement(e); + MatrixIndexT j = p.first; + Real alpha_B_kj = alpha * p.second; + Real *this_col_j = this->Data() + j; + // Add to entire 'j'th column of *this at once using cblas_Xaxpy. + // pass stride to write a colmun as matrices are stored in row major order. + cblas_Xaxpy(this_num_rows, alpha_B_kj, a_col_k, A.stride_, + this_col_j, this->stride_); + //for (MatrixIndexT i = 0; i < this_num_rows; ++i) + // this_col_j[i*this->stride_] += alpha_B_kj * a_col_k[i*A.stride_]; + } + } + } else { + KALDI_ASSERT(NumRows() == A.NumRows()); + KALDI_ASSERT(NumCols() == B.NumRows()); + KALDI_ASSERT(A.NumCols() == B.NumCols()); + + this->Scale(beta); + MatrixIndexT b_num_rows = B.NumRows(), + this_num_rows = this->NumRows(); + // Iterate over the rows of sparse matrix B and columns of *this. + for (MatrixIndexT j = 0; j < b_num_rows; ++j) { + const SparseVector &B_row_j = B.Row(j); + MatrixIndexT num_elems = B_row_j.NumElements(); + Real *this_col_j = this->Data() + j; + for (MatrixIndexT e = 0; e < num_elems; ++e) { + const std::pair &p = B_row_j.GetElement(e); + MatrixIndexT k = p.first; + Real alpha_B_jk = alpha * p.second; + const Real *a_col_k = A.Data() + k; + // Add to entire 'j'th column of *this at once using cblas_Xaxpy. + // pass stride to write a column as matrices are stored in row major order. 
+ cblas_Xaxpy(this_num_rows, alpha_B_jk, a_col_k, A.stride_, + this_col_j, this->stride_); + //for (MatrixIndexT i = 0; i < this_num_rows; ++i) + // this_col_j[i*this->stride_] += alpha_B_jk * a_col_k[i*A.stride_]; + } + } + } +} + +template +template +void MatrixBase::AddSp(const Real alpha, const SpMatrix &S) { + KALDI_ASSERT(S.NumRows() == NumRows() && S.NumRows() == NumCols()); + Real *data = data_; const OtherReal *sdata = S.Data(); + MatrixIndexT num_rows = NumRows(), stride = Stride(); + for (MatrixIndexT i = 0; i < num_rows; i++) { + for (MatrixIndexT j = 0; j < i; j++, sdata++) { + data[i*stride + j] += alpha * *sdata; + data[j*stride + i] += alpha * *sdata; + } + data[i*stride + i] += alpha * *sdata++; + } +} + +// instantiate the template above. +template +void MatrixBase::AddSp(const float alpha, const SpMatrix &S); +template +void MatrixBase::AddSp(const double alpha, const SpMatrix &S); +template +void MatrixBase::AddSp(const float alpha, const SpMatrix &S); +template +void MatrixBase::AddSp(const double alpha, const SpMatrix &S); + + +template +void MatrixBase::AddDiagVecMat( + const Real alpha, const VectorBase &v, + const MatrixBase &M, + MatrixTransposeType transM, + Real beta) { + if (beta != 1.0) this->Scale(beta); + + if (transM == kNoTrans) { + KALDI_ASSERT(SameDim(*this, M)); + } else { + KALDI_ASSERT(M.NumRows() == NumCols() && M.NumCols() == NumRows()); + } + KALDI_ASSERT(v.Dim() == this->NumRows()); + + MatrixIndexT M_row_stride = M.Stride(), M_col_stride = 1, stride = stride_, + num_rows = num_rows_, num_cols = num_cols_; + if (transM == kTrans) std::swap(M_row_stride, M_col_stride); + Real *data = data_; + const Real *Mdata = M.Data(), *vdata = v.Data(); + if (num_rows_ == 0) return; + for (MatrixIndexT i = 0; i < num_rows; i++, data += stride, Mdata += M_row_stride, vdata++) + cblas_Xaxpy(num_cols, alpha * *vdata, Mdata, M_col_stride, data, 1); +} + +template +void MatrixBase::AddMatDiagVec( + const Real alpha, + const MatrixBase &M, 
MatrixTransposeType transM, + VectorBase &v, + Real beta) { + + if (beta != 1.0) this->Scale(beta); + + if (transM == kNoTrans) { + KALDI_ASSERT(SameDim(*this, M)); + } else { + KALDI_ASSERT(M.NumRows() == NumCols() && M.NumCols() == NumRows()); + } + KALDI_ASSERT(v.Dim() == this->NumCols()); + + MatrixIndexT M_row_stride = M.Stride(), + M_col_stride = 1, + stride = stride_, + num_rows = num_rows_, + num_cols = num_cols_; + + if (transM == kTrans) + std::swap(M_row_stride, M_col_stride); + + Real *data = data_; + const Real *Mdata = M.Data(), *vdata = v.Data(); + if (num_rows_ == 0) return; + for (MatrixIndexT i = 0; i < num_rows; i++){ + for(MatrixIndexT j = 0; j < num_cols; j ++ ){ + data[i*stride + j] += alpha * vdata[j] * Mdata[i*M_row_stride + j*M_col_stride]; + } + } +} + +template +void MatrixBase::AddMatMatElements(const Real alpha, + const MatrixBase& A, + const MatrixBase& B, + const Real beta) { + KALDI_ASSERT(A.NumRows() == B.NumRows() && A.NumCols() == B.NumCols()); + KALDI_ASSERT(A.NumRows() == NumRows() && A.NumCols() == NumCols()); + Real *data = data_; + const Real *dataA = A.Data(); + const Real *dataB = B.Data(); + + for (MatrixIndexT i = 0; i < num_rows_; i++) { + for (MatrixIndexT j = 0; j < num_cols_; j++) { + data[j] = beta*data[j] + alpha*dataA[j]*dataB[j]; + } + data += Stride(); + dataA += A.Stride(); + dataB += B.Stride(); + } +} + +#if !defined(HAVE_ATLAS) && !defined(USE_KALDI_SVD) +// **************************************************************************** +// **************************************************************************** +template +void MatrixBase::LapackGesvd(VectorBase *s, MatrixBase *U_in, + MatrixBase *V_in) { + KALDI_ASSERT(s != NULL && U_in != this && V_in != this); + + Matrix tmpU, tmpV; + if (U_in == NULL) tmpU.Resize(this->num_rows_, 1); // work-space if U_in empty. + if (V_in == NULL) tmpV.Resize(1, this->num_cols_); // work-space if V_in empty. 
+ + /// Impementation notes: + /// Lapack works in column-order, therefore the dimensions of *this are + /// swapped as well as the U and V matrices. + + KaldiBlasInt M = num_cols_; + KaldiBlasInt N = num_rows_; + KaldiBlasInt LDA = Stride(); + + KALDI_ASSERT(N>=M); // NumRows >= columns. + + if (U_in) { + KALDI_ASSERT((int)U_in->num_rows_ == N && (int)U_in->num_cols_ == M); + } + if (V_in) { + KALDI_ASSERT((int)V_in->num_rows_ == M && (int)V_in->num_cols_ == M); + } + KALDI_ASSERT((int)s->Dim() == std::min(M, N)); + + MatrixBase *U = (U_in ? U_in : &tmpU); + MatrixBase *V = (V_in ? V_in : &tmpV); + + KaldiBlasInt V_stride = V->Stride(); + KaldiBlasInt U_stride = U->Stride(); + + // Original LAPACK recipe + // KaldiBlasInt l_work = std::max(std::max + // (1, 3*std::min(M, N)+std::max(M, N)), 5*std::min(M, N))*2; + KaldiBlasInt l_work = -1; + Real work_query; + KaldiBlasInt result; + + // query for work space + char *u_job = const_cast(U_in ? "s" : "N"); // "s" == skinny, "N" == "none." + char *v_job = const_cast(V_in ? "s" : "N"); // "s" == skinny, "N" == "none." + clapack_Xgesvd(v_job, u_job, + &M, &N, data_, &LDA, + s->Data(), + V->Data(), &V_stride, + U->Data(), &U_stride, + &work_query, &l_work, + &result); + + KALDI_ASSERT(result >= 0 && "Call to CLAPACK dgesvd_ called with wrong arguments"); + + l_work = static_cast(work_query); + Real *p_work; + void *temp; + if ((p_work = static_cast( + KALDI_MEMALIGN(16, sizeof(Real)*l_work, &temp))) == NULL) + throw std::bad_alloc(); + + // perform svd + clapack_Xgesvd(v_job, u_job, + &M, &N, data_, &LDA, + s->Data(), + V->Data(), &V_stride, + U->Data(), &U_stride, + p_work, &l_work, + &result); + + KALDI_ASSERT(result >= 0 && "Call to CLAPACK dgesvd_ called with wrong arguments"); + + if (result != 0) { + KALDI_WARN << "CLAPACK sgesvd_ : some weird convergence not satisfied"; + } + KALDI_MEMALIGN_FREE(p_work); +} + +#endif + +// Copy constructor. Copies data to newly allocated memory. 
+template +Matrix::Matrix (const MatrixBase & M, + MatrixTransposeType trans/*=kNoTrans*/) + : MatrixBase() { + if (trans == kNoTrans) { + Resize(M.num_rows_, M.num_cols_); + this->CopyFromMat(M); + } else { + Resize(M.num_cols_, M.num_rows_); + this->CopyFromMat(M, kTrans); + } +} + +// Copy constructor. Copies data to newly allocated memory. +template +Matrix::Matrix (const Matrix & M): + MatrixBase() { + Resize(M.num_rows_, M.num_cols_); + this->CopyFromMat(M); +} + +/// Copy constructor from another type. +template +template +Matrix::Matrix(const MatrixBase & M, + MatrixTransposeType trans) : MatrixBase() { + if (trans == kNoTrans) { + Resize(M.NumRows(), M.NumCols()); + this->CopyFromMat(M); + } else { + Resize(M.NumCols(), M.NumRows()); + this->CopyFromMat(M, kTrans); + } +} + +// Instantiate this constructor for float->double and double->float. +template +Matrix::Matrix(const MatrixBase & M, + MatrixTransposeType trans); +template +Matrix::Matrix(const MatrixBase & M, + MatrixTransposeType trans); + +template +inline void Matrix::Init(const MatrixIndexT rows, + const MatrixIndexT cols, + const MatrixStrideType stride_type) { + if (rows * cols == 0) { + KALDI_ASSERT(rows == 0 && cols == 0); + this->num_rows_ = 0; + this->num_cols_ = 0; + this->stride_ = 0; + this->data_ = NULL; + return; + } + KALDI_ASSERT(rows > 0 && cols > 0); + MatrixIndexT skip, stride; + size_t size; + void *data; // aligned memory block + void *temp; // memory block to be really freed + + // compute the size of skip and real cols + skip = ((16 / sizeof(Real)) - cols % (16 / sizeof(Real))) + % (16 / sizeof(Real)); + stride = cols + skip; + size = static_cast(rows) * static_cast(stride) + * sizeof(Real); + + // allocate the memory and set the right dimensions and parameters + if (NULL != (data = KALDI_MEMALIGN(16, size, &temp))) { + MatrixBase::data_ = static_cast (data); + MatrixBase::num_rows_ = rows; + MatrixBase::num_cols_ = cols; + MatrixBase::stride_ = (stride_type == kDefaultStride 
? stride : cols); + } else { + throw std::bad_alloc(); + } +} + +template +void Matrix::Resize(const MatrixIndexT rows, + const MatrixIndexT cols, + MatrixResizeType resize_type, + MatrixStrideType stride_type) { + // the next block uses recursion to handle what we have to do if + // resize_type == kCopyData. + if (resize_type == kCopyData) { + if (this->data_ == NULL || rows == 0) resize_type = kSetZero; // nothing to copy. + else if (rows == this->num_rows_ && cols == this->num_cols_ && + (stride_type == kDefaultStride || this->stride_ == this->num_cols_)) { return; } // nothing to do. + else { + // set tmp to a matrix of the desired size; if new matrix + // is bigger in some dimension, zero it. + MatrixResizeType new_resize_type = + (rows > this->num_rows_ || cols > this->num_cols_) ? kSetZero : kUndefined; + Matrix tmp(rows, cols, new_resize_type, stride_type); + MatrixIndexT rows_min = std::min(rows, this->num_rows_), + cols_min = std::min(cols, this->num_cols_); + tmp.Range(0, rows_min, 0, cols_min). + CopyFromMat(this->Range(0, rows_min, 0, cols_min)); + tmp.Swap(this); + // and now let tmp go out of scope, deleting what was in *this. + return; + } + } + // At this point, resize_type == kSetZero or kUndefined. + + if (MatrixBase::data_ != NULL) { + if (rows == MatrixBase::num_rows_ + && cols == MatrixBase::num_cols_) { + if (resize_type == kSetZero) + this->SetZero(); + return; + } + else + Destroy(); + } + Init(rows, cols, stride_type); + if (resize_type == kSetZero) MatrixBase::SetZero(); +} + +template +template +void MatrixBase::CopyFromMat(const MatrixBase &M, + MatrixTransposeType Trans) { + if (sizeof(Real) == sizeof(OtherReal) && + static_cast(M.Data()) == + static_cast(this->Data())) { + // CopyFromMat called on same data. Nothing to do (except sanity checks). 
+ KALDI_ASSERT(Trans == kNoTrans && M.NumRows() == NumRows() && + M.NumCols() == NumCols() && M.Stride() == Stride()); + return; + } + if (Trans == kNoTrans) { + KALDI_ASSERT(num_rows_ == M.NumRows() && num_cols_ == M.NumCols()); + for (MatrixIndexT i = 0; i < num_rows_; i++) + (*this).Row(i).CopyFromVec(M.Row(i)); + } else { + KALDI_ASSERT(num_cols_ == M.NumRows() && num_rows_ == M.NumCols()); + int32 this_stride = stride_, other_stride = M.Stride(); + Real *this_data = data_; + const OtherReal *other_data = M.Data(); + for (MatrixIndexT i = 0; i < num_rows_; i++) + for (MatrixIndexT j = 0; j < num_cols_; j++) + this_data[i * this_stride + j] = other_data[j * other_stride + i]; + } +} + +// template instantiations. +template +void MatrixBase::CopyFromMat(const MatrixBase & M, + MatrixTransposeType Trans); +template +void MatrixBase::CopyFromMat(const MatrixBase & M, + MatrixTransposeType Trans); +template +void MatrixBase::CopyFromMat(const MatrixBase & M, + MatrixTransposeType Trans); +template +void MatrixBase::CopyFromMat(const MatrixBase & M, + MatrixTransposeType Trans); + +// Specialize the template for CopyFromSp for float, float. +template<> +template<> +void MatrixBase::CopyFromSp(const SpMatrix & M) { + KALDI_ASSERT(num_rows_ == M.NumRows() && num_cols_ == num_rows_); + MatrixIndexT num_rows = num_rows_, stride = stride_; + const float *Mdata = M.Data(); + float *row_data = data_, *col_data = data_; + for (MatrixIndexT i = 0; i < num_rows; i++) { + cblas_scopy(i+1, Mdata, 1, row_data, 1); // copy to the row. + cblas_scopy(i, Mdata, 1, col_data, stride); // copy to the column. + Mdata += i+1; + row_data += stride; + col_data += 1; + } +} + +// Specialize the template for CopyFromSp for double, double. 
+template<> +template<> +void MatrixBase::CopyFromSp(const SpMatrix & M) { + KALDI_ASSERT(num_rows_ == M.NumRows() && num_cols_ == num_rows_); + MatrixIndexT num_rows = num_rows_, stride = stride_; + const double *Mdata = M.Data(); + double *row_data = data_, *col_data = data_; + for (MatrixIndexT i = 0; i < num_rows; i++) { + cblas_dcopy(i+1, Mdata, 1, row_data, 1); // copy to the row. + cblas_dcopy(i, Mdata, 1, col_data, stride); // copy to the column. + Mdata += i+1; + row_data += stride; + col_data += 1; + } +} + + +template +template +void MatrixBase::CopyFromSp(const SpMatrix & M) { + KALDI_ASSERT(num_rows_ == M.NumRows() && num_cols_ == num_rows_); + // MORE EFFICIENT IF LOWER TRIANGULAR! Reverse code otherwise. + for (MatrixIndexT i = 0; i < num_rows_; i++) { + for (MatrixIndexT j = 0; j < i; j++) { + (*this)(j, i) = (*this)(i, j) = M(i, j); + } + (*this)(i, i) = M(i, i); + } +} + +// Instantiate this function +template +void MatrixBase::CopyFromSp(const SpMatrix & M); +template +void MatrixBase::CopyFromSp(const SpMatrix & M); + + +template +template +void MatrixBase::CopyFromTp(const TpMatrix & M, + MatrixTransposeType Trans) { + if (Trans == kNoTrans) { + KALDI_ASSERT(num_rows_ == M.NumRows() && num_cols_ == num_rows_); + SetZero(); + Real *out_i = data_; + const OtherReal *in_i = M.Data(); + for (MatrixIndexT i = 0; i < num_rows_; i++, out_i += stride_, in_i += i) { + for (MatrixIndexT j = 0; j <= i; j++) + out_i[j] = in_i[j]; + } + } else { + SetZero(); + KALDI_ASSERT(num_rows_ == M.NumRows() && num_cols_ == num_rows_); + MatrixIndexT stride = stride_; + Real *out_i = data_; + const OtherReal *in_i = M.Data(); + for (MatrixIndexT i = 0; i < num_rows_; i++, out_i ++, in_i += i) { + for (MatrixIndexT j = 0; j <= i; j++) + out_i[j*stride] = in_i[j]; + } + } +} + +template +void MatrixBase::CopyFromTp(const TpMatrix & M, + MatrixTransposeType trans); +template +void MatrixBase::CopyFromTp(const TpMatrix & M, + MatrixTransposeType trans); +template +void 
MatrixBase::CopyFromTp(const TpMatrix & M, + MatrixTransposeType trans); +template +void MatrixBase::CopyFromTp(const TpMatrix & M, + MatrixTransposeType trans); + + +template +void MatrixBase::CopyRowsFromVec(const VectorBase &rv) { + if (rv.Dim() == num_rows_*num_cols_) { + if (stride_ == num_cols_) { + // one big copy operation. + const Real *rv_data = rv.Data(); + std::memcpy(data_, rv_data, sizeof(Real)*num_rows_*num_cols_); + } else { + const Real *rv_data = rv.Data(); + for (MatrixIndexT r = 0; r < num_rows_; r++) { + Real *row_data = RowData(r); + for (MatrixIndexT c = 0; c < num_cols_; c++) { + row_data[c] = rv_data[c]; + } + rv_data += num_cols_; + } + } + } else if (rv.Dim() == num_cols_) { + const Real *rv_data = rv.Data(); + for (MatrixIndexT r = 0; r < num_rows_; r++) + std::memcpy(RowData(r), rv_data, sizeof(Real)*num_cols_); + } else { + KALDI_ERR << "Wrong sized arguments"; + } +} + +template +template +void MatrixBase::CopyRowsFromVec(const VectorBase &rv) { + if (rv.Dim() == num_rows_*num_cols_) { + const OtherReal *rv_data = rv.Data(); + for (MatrixIndexT r = 0; r < num_rows_; r++) { + Real *row_data = RowData(r); + for (MatrixIndexT c = 0; c < num_cols_; c++) { + row_data[c] = static_cast(rv_data[c]); + } + rv_data += num_cols_; + } + } else if (rv.Dim() == num_cols_) { + const OtherReal *rv_data = rv.Data(); + Real *first_row_data = RowData(0); + for (MatrixIndexT c = 0; c < num_cols_; c++) + first_row_data[c] = rv_data[c]; + for (MatrixIndexT r = 1; r < num_rows_; r++) + std::memcpy(RowData(r), first_row_data, sizeof(Real)*num_cols_); + } else { + KALDI_ERR << "Wrong sized arguments."; + } +} + + +template +void MatrixBase::CopyRowsFromVec(const VectorBase &rv); +template +void MatrixBase::CopyRowsFromVec(const VectorBase &rv); + +template +void MatrixBase::CopyColsFromVec(const VectorBase &rv) { + if (rv.Dim() == num_rows_*num_cols_) { + const Real *v_inc_data = rv.Data(); + Real *m_inc_data = data_; + + for (MatrixIndexT c = 0; c < 
num_cols_; c++) { + for (MatrixIndexT r = 0; r < num_rows_; r++) { + m_inc_data[r * stride_] = v_inc_data[r]; + } + v_inc_data += num_rows_; + m_inc_data ++; + } + } else if (rv.Dim() == num_rows_) { + const Real *v_inc_data = rv.Data(); + Real *m_inc_data = data_; + for (MatrixIndexT r = 0; r < num_rows_; r++) { + Real value = *(v_inc_data++); + for (MatrixIndexT c = 0; c < num_cols_; c++) + m_inc_data[c] = value; + m_inc_data += stride_; + } + } else { + KALDI_ERR << "Wrong size of arguments."; + } +} + + +template +void MatrixBase::CopyRowFromVec(const VectorBase &rv, const MatrixIndexT row) { + KALDI_ASSERT(rv.Dim() == num_cols_ && + static_cast(row) < + static_cast(num_rows_)); + + const Real *rv_data = rv.Data(); + Real *row_data = RowData(row); + + std::memcpy(row_data, rv_data, num_cols_ * sizeof(Real)); +} + +template +void MatrixBase::CopyDiagFromVec(const VectorBase &rv) { + KALDI_ASSERT(rv.Dim() == std::min(num_cols_, num_rows_)); + const Real *rv_data = rv.Data(), *rv_end = rv_data + rv.Dim(); + Real *my_data = this->Data(); + for (; rv_data != rv_end; rv_data++, my_data += (this->stride_+1)) + *my_data = *rv_data; +} + +template +void MatrixBase::CopyColFromVec(const VectorBase &rv, + const MatrixIndexT col) { + KALDI_ASSERT(rv.Dim() == num_rows_ && + static_cast(col) < + static_cast(num_cols_)); + + const Real *rv_data = rv.Data(); + Real *col_data = data_ + col; + + for (MatrixIndexT r = 0; r < num_rows_; r++) + col_data[r * stride_] = rv_data[r]; +} + + + +template +void Matrix::RemoveRow(MatrixIndexT i) { + KALDI_ASSERT(static_cast(i) < + static_cast(MatrixBase::num_rows_) + && "Access out of matrix"); + for (MatrixIndexT j = i + 1; j < MatrixBase::num_rows_; j++) + MatrixBase::Row(j-1).CopyFromVec( MatrixBase::Row(j)); + MatrixBase::num_rows_--; +} + +template +void Matrix::Destroy() { + // we need to free the data block if it was defined + if (NULL != MatrixBase::data_) + KALDI_MEMALIGN_FREE( MatrixBase::data_); + MatrixBase::data_ = NULL; + 
MatrixBase::num_rows_ = MatrixBase::num_cols_ + = MatrixBase::stride_ = 0; +} + + + +template +void MatrixBase::MulElements(const MatrixBase &a) { + KALDI_ASSERT(a.NumRows() == num_rows_ && a.NumCols() == num_cols_); + + if (num_cols_ == stride_ && num_cols_ == a.stride_) { + mul_elements(num_rows_ * num_cols_, a.data_, data_); + } else { + MatrixIndexT a_stride = a.stride_, stride = stride_; + Real *data = data_, *a_data = a.data_; + for (MatrixIndexT i = 0; i < num_rows_; i++) { + mul_elements(num_cols_, a_data, data); + a_data += a_stride; + data += stride; + } + } +} + +template +void MatrixBase::DivElements(const MatrixBase &a) { + KALDI_ASSERT(a.NumRows() == num_rows_ && a.NumCols() == num_cols_); + MatrixIndexT i; + MatrixIndexT j; + + for (i = 0; i < num_rows_; i++) { + for (j = 0; j < num_cols_; j++) { + (*this)(i, j) /= a(i, j); + } + } +} + +template +Real MatrixBase::Sum() const { + double sum = 0.0; + + for (MatrixIndexT i = 0; i < num_rows_; i++) { + for (MatrixIndexT j = 0; j < num_cols_; j++) { + sum += (*this)(i, j); + } + } + + return (Real)sum; +} + +template void MatrixBase::Max(const MatrixBase &A) { + KALDI_ASSERT(A.NumRows() == NumRows() && A.NumCols() == NumCols()); + for (MatrixIndexT row = 0; row < num_rows_; row++) { + Real *row_data = RowData(row); + const Real *other_row_data = A.RowData(row); + MatrixIndexT num_cols = num_cols_; + for (MatrixIndexT col = 0; col < num_cols; col++) { + row_data[col] = std::max(row_data[col], + other_row_data[col]); + } + } +} + +template void MatrixBase::Min(const MatrixBase &A) { + KALDI_ASSERT(A.NumRows() == NumRows() && A.NumCols() == NumCols()); + for (MatrixIndexT row = 0; row < num_rows_; row++) { + Real *row_data = RowData(row); + const Real *other_row_data = A.RowData(row); + MatrixIndexT num_cols = num_cols_; + for (MatrixIndexT col = 0; col < num_cols; col++) { + row_data[col] = std::min(row_data[col], + other_row_data[col]); + } + } +} + + +template void MatrixBase::Scale(Real alpha) { + if 
(alpha == 1.0) return; + if (num_rows_ == 0) return; + if (num_cols_ == stride_) { + cblas_Xscal(static_cast(num_rows_) * static_cast(num_cols_), + alpha, data_,1); + } else { + Real *data = data_; + for (MatrixIndexT i = 0; i < num_rows_; ++i, data += stride_) { + cblas_Xscal(num_cols_, alpha, data,1); + } + } +} + +template // scales each row by scale[i]. +void MatrixBase::MulRowsVec(const VectorBase &scale) { + KALDI_ASSERT(scale.Dim() == num_rows_); + MatrixIndexT M = num_rows_, N = num_cols_; + + for (MatrixIndexT i = 0; i < M; i++) { + Real this_scale = scale(i); + for (MatrixIndexT j = 0; j < N; j++) { + (*this)(i, j) *= this_scale; + } + } +} + + +template +void MatrixBase::MulRowsGroupMat(const MatrixBase &src) { + KALDI_ASSERT(src.NumRows() == this->NumRows() && + this->NumCols() % src.NumCols() == 0); + int32 group_size = this->NumCols() / src.NumCols(), + num_groups = this->NumCols() / group_size, + num_rows = this->NumRows(); + + for (MatrixIndexT i = 0; i < num_rows; i++) { + Real *data = this->RowData(i); + for (MatrixIndexT j = 0; j < num_groups; j++, data += group_size) { + Real scale = src(i, j); + cblas_Xscal(group_size, scale, data, 1); + } + } +} + +template +void MatrixBase::GroupPnormDeriv(const MatrixBase &input, + const MatrixBase &output, + Real power) { + KALDI_ASSERT(input.NumCols() == this->NumCols() && input.NumRows() == this->NumRows()); + KALDI_ASSERT(this->NumCols() % output.NumCols() == 0 && + this->NumRows() == output.NumRows()); + + int group_size = this->NumCols() / output.NumCols(), + num_rows = this->NumRows(), num_cols = this->NumCols(); + + if (power == 1.0) { + for (MatrixIndexT i = 0; i < num_rows; i++) { + for (MatrixIndexT j = 0; j < num_cols; j++) { + Real input_val = input(i, j); + (*this)(i, j) = (input_val == 0 ? 0 : (input_val > 0 ? 
1 : -1)); + } + } + } else if (power == std::numeric_limits::infinity()) { + for (MatrixIndexT i = 0; i < num_rows; i++) { + for (MatrixIndexT j = 0; j < num_cols; j++) { + Real output_val = output(i, j / group_size), input_val = input(i, j); + if (output_val == 0) + (*this)(i, j) = 0; + else + (*this)(i, j) = (std::abs(input_val) == output_val ? 1.0 : 0.0) + * (input_val >= 0 ? 1 : -1); + } + } + } else { + for (MatrixIndexT i = 0; i < num_rows; i++) { + for (MatrixIndexT j = 0; j < num_cols; j++) { + Real output_val = output(i, j / group_size), + input_val = input(i, j); + if (output_val == 0) + (*this)(i, j) = 0; + else + (*this)(i, j) = pow(std::abs(input_val), power - 1) * + pow(output_val, 1 - power) * (input_val >= 0 ? 1 : -1) ; + } + } + } +} + +template +void MatrixBase::GroupMaxDeriv(const MatrixBase &input, + const MatrixBase &output) { + KALDI_ASSERT(input.NumCols() == this->NumCols() && + input.NumRows() == this->NumRows()); + KALDI_ASSERT(this->NumCols() % output.NumCols() == 0 && + this->NumRows() == output.NumRows()); + + int group_size = this->NumCols() / output.NumCols(), + num_rows = this->NumRows(), num_cols = this->NumCols(); + + for (MatrixIndexT i = 0; i < num_rows; i++) { + for (MatrixIndexT j = 0; j < num_cols; j++) { + Real input_val = input(i, j); + Real output_val = output(i, j / group_size); + (*this)(i, j) = (input_val == output_val ? 1 : 0); + } + } +} + +template // scales each column by scale[i]. 
+void MatrixBase::MulColsVec(const VectorBase &scale) { + KALDI_ASSERT(scale.Dim() == num_cols_); + for (MatrixIndexT i = 0; i < num_rows_; i++) { + for (MatrixIndexT j = 0; j < num_cols_; j++) { + Real this_scale = scale(j); + (*this)(i, j) *= this_scale; + } + } +} + +template +void MatrixBase::SetZero() { + if (num_cols_ == stride_) + memset(data_, 0, sizeof(Real)*num_rows_*num_cols_); + else + for (MatrixIndexT row = 0; row < num_rows_; row++) + memset(data_ + row*stride_, 0, sizeof(Real)*num_cols_); +} + +template +void MatrixBase::Set(Real value) { + for (MatrixIndexT row = 0; row < num_rows_; row++) { + for (MatrixIndexT col = 0; col < num_cols_; col++) { + (*this)(row, col) = value; + } + } +} + +template +void MatrixBase::SetUnit() { + SetZero(); + for (MatrixIndexT row = 0; row < std::min(num_rows_, num_cols_); row++) + (*this)(row, row) = 1.0; +} + +template +void MatrixBase::SetRandn() { + kaldi::RandomState rstate; + for (MatrixIndexT row = 0; row < num_rows_; row++) { + Real *row_data = this->RowData(row); + MatrixIndexT nc = (num_cols_ % 2 == 1) ? num_cols_ - 1 : num_cols_; + for (MatrixIndexT col = 0; col < nc; col += 2) { + kaldi::RandGauss2(row_data + col, row_data + col + 1, &rstate); + } + if (nc != num_cols_) row_data[nc] = static_cast(kaldi::RandGauss(&rstate)); + } +} + +template +void MatrixBase::SetRandUniform() { + kaldi::RandomState rstate; + for (MatrixIndexT row = 0; row < num_rows_; row++) { + Real *row_data = this->RowData(row); + for (MatrixIndexT col = 0; col < num_cols_; col++, row_data++) { + *row_data = static_cast(kaldi::RandUniform(&rstate)); + } + } +} + +template +void MatrixBase::Write(std::ostream &os, bool binary) const { + if (!os.good()) { + KALDI_ERR << "Failed to write matrix to stream: stream not good"; + } + if (binary) { // Use separate binary and text formats, + // since in binary mode we need to know if it's float or double. + std::string my_token = (sizeof(Real) == 4 ? 
"FM" : "DM"); + + WriteToken(os, binary, my_token); + { + int32 rows = this->num_rows_; // make the size 32-bit on disk. + int32 cols = this->num_cols_; + KALDI_ASSERT(this->num_rows_ == (MatrixIndexT) rows); + KALDI_ASSERT(this->num_cols_ == (MatrixIndexT) cols); + WriteBasicType(os, binary, rows); + WriteBasicType(os, binary, cols); + } + if (Stride() == NumCols()) + os.write(reinterpret_cast (Data()), sizeof(Real) + * static_cast(num_rows_) * static_cast(num_cols_)); + else + for (MatrixIndexT i = 0; i < num_rows_; i++) + os.write(reinterpret_cast (RowData(i)), sizeof(Real) + * num_cols_); + if (!os.good()) { + KALDI_ERR << "Failed to write matrix to stream"; + } + } else { // text mode. + if (num_cols_ == 0) { + os << " [ ]\n"; + } else { + os << " ["; + for (MatrixIndexT i = 0; i < num_rows_; i++) { + os << "\n "; + for (MatrixIndexT j = 0; j < num_cols_; j++) + os << (*this)(i, j) << " "; + } + os << "]\n"; + } + } +} + + +template +void MatrixBase::Read(std::istream & is, bool binary, bool add) { + if (add) { + Matrix tmp(num_rows_, num_cols_); + tmp.Read(is, binary, false); // read without adding. + if (tmp.num_rows_ != this->num_rows_ || tmp.num_cols_ != this->num_cols_) + KALDI_ERR << "MatrixBase::Read, size mismatch " + << this->num_rows_ << ", " << this->num_cols_ + << " vs. " << tmp.num_rows_ << ", " << tmp.num_cols_; + this->AddMat(1.0, tmp); + return; + } + // now assume add == false. + + // In order to avoid rewriting this, we just declare a Matrix and + // use it to read the data, then copy. + Matrix tmp; + tmp.Read(is, binary, false); + if (tmp.NumRows() != NumRows() || tmp.NumCols() != NumCols()) { + KALDI_ERR << "MatrixBase::Read, size mismatch " + << NumRows() << " x " << NumCols() << " versus " + << tmp.NumRows() << " x " << tmp.NumCols(); + } + CopyFromMat(tmp); +} + + +template +void Matrix::Read(std::istream & is, bool binary, bool add) { + if (add) { + Matrix tmp; + tmp.Read(is, binary, false); // read without adding. 
+ if (this->num_rows_ == 0) this->Resize(tmp.num_rows_, tmp.num_cols_); + else { + if (this->num_rows_ != tmp.num_rows_ || this->num_cols_ != tmp.num_cols_) { + if (tmp.num_rows_ == 0) return; // do nothing in this case. + else KALDI_ERR << "Matrix::Read, size mismatch " + << this->num_rows_ << ", " << this->num_cols_ + << " vs. " << tmp.num_rows_ << ", " << tmp.num_cols_; + } + } + this->AddMat(1.0, tmp); + return; + } + + // now assume add == false. + MatrixIndexT pos_at_start = is.tellg(); + std::ostringstream specific_error; + + if (binary) { // Read in binary mode. + int peekval = Peek(is, binary); + if (peekval == 'C') { + // This code enables us to read CompressedMatrix as a regular matrix. + CompressedMatrix compressed_mat; + compressed_mat.Read(is, binary); // at this point, add == false. + this->Resize(compressed_mat.NumRows(), compressed_mat.NumCols()); + compressed_mat.CopyToMat(this); + return; + } + const char *my_token = (sizeof(Real) == 4 ? "FM" : "DM"); + char other_token_start = (sizeof(Real) == 4 ? 'D' : 'F'); + if (peekval == other_token_start) { // need to instantiate the other type to read it. + typedef typename OtherReal::Real OtherType; // if Real == float, OtherType == double, and vice versa. + Matrix other(this->num_rows_, this->num_cols_); + other.Read(is, binary, false); // add is false at this point anyway. + this->Resize(other.NumRows(), other.NumCols()); + this->CopyFromMat(other); + return; + } + std::string token; + ReadToken(is, binary, &token); + if (token != my_token) { + if (token.length() > 20) token = token.substr(0, 17) + "..."; + specific_error << ": Expected token " << my_token << ", got " << token; + goto bad; + } + int32 rows, cols; + ReadBasicType(is, binary, &rows); // throws on error. + ReadBasicType(is, binary, &cols); // throws on error. 
+ if ((MatrixIndexT)rows != this->num_rows_ || (MatrixIndexT)cols != this->num_cols_) { + this->Resize(rows, cols); + } + if (this->Stride() == this->NumCols() && rows*cols!=0) { + is.read(reinterpret_cast(this->Data()), + sizeof(Real)*rows*cols); + if (is.fail()) goto bad; + } else { + for (MatrixIndexT i = 0; i < (MatrixIndexT)rows; i++) { + is.read(reinterpret_cast(this->RowData(i)), sizeof(Real)*cols); + if (is.fail()) goto bad; + } + } + if (is.eof()) return; + if (is.fail()) goto bad; + return; + } else { // Text mode. + std::string str; + is >> str; // get a token + if (is.fail()) { specific_error << ": Expected \"[\", got EOF"; goto bad; } + // if ((str.compare("DM") == 0) || (str.compare("FM") == 0)) { // Back compatibility. + // is >> str; // get #rows + // is >> str; // get #cols + // is >> str; // get "[" + // } + if (str == "[]") { Resize(0, 0); return; } // Be tolerant of variants. + else if (str != "[") { + if (str.length() > 20) str = str.substr(0, 17) + "..."; + specific_error << ": Expected \"[\", got \"" << str << '"'; + goto bad; + } + // At this point, we have read "[". + std::vector* > data; + std::vector *cur_row = new std::vector; + while (1) { + int i = is.peek(); + if (i == -1) { specific_error << "Got EOF while reading matrix data"; goto cleanup; } + else if (static_cast(i) == ']') { // Finished reading matrix. + is.get(); // eat the "]". + i = is.peek(); + if (static_cast(i) == '\r') { + is.get(); + is.get(); // get \r\n (must eat what we wrote) + } else if (static_cast(i) == '\n') { is.get(); } // get \n (must eat what we wrote) + if (is.fail()) { + KALDI_WARN << "After end of matrix data, read error."; + // we got the data we needed, so just warn for this error. + } + // Now process the data. 
+ if (!cur_row->empty()) data.push_back(cur_row); + else delete(cur_row); + cur_row = NULL; + if (data.empty()) { this->Resize(0, 0); return; } + else { + int32 num_rows = data.size(), num_cols = data[0]->size(); + this->Resize(num_rows, num_cols); + for (int32 i = 0; i < num_rows; i++) { + if (static_cast(data[i]->size()) != num_cols) { + specific_error << "Matrix has inconsistent #cols: " << num_cols + << " vs." << data[i]->size() << " (processing row" + << i << ")"; + goto cleanup; + } + for (int32 j = 0; j < num_cols; j++) + (*this)(i, j) = (*(data[i]))[j]; + delete data[i]; + data[i] = NULL; + } + } + return; + } else if (static_cast(i) == '\n' || static_cast(i) == ';') { + // End of matrix row. + is.get(); + if (cur_row->size() != 0) { + data.push_back(cur_row); + cur_row = new std::vector; + cur_row->reserve(data.back()->size()); + } + } else if ( (i >= '0' && i <= '9') || i == '-' ) { // A number... + Real r; + is >> r; + if (is.fail()) { + specific_error << "Stream failure/EOF while reading matrix data."; + goto cleanup; + } + cur_row->push_back(r); + } else if (isspace(i)) { + is.get(); // eat the space and do nothing. + } else { // NaN or inf or error. + std::string str; + is >> str; + if (!KALDI_STRCASECMP(str.c_str(), "inf") || + !KALDI_STRCASECMP(str.c_str(), "infinity")) { + cur_row->push_back(std::numeric_limits::infinity()); + KALDI_WARN << "Reading infinite value into matrix."; + } else if (!KALDI_STRCASECMP(str.c_str(), "nan")) { + cur_row->push_back(std::numeric_limits::quiet_NaN()); + KALDI_WARN << "Reading NaN value into matrix."; + } else { + if (str.length() > 20) str = str.substr(0, 17) + "..."; + specific_error << "Expecting numeric matrix data, got " << str; + goto cleanup; + } + } + } + // Note, we never leave the while () loop before this + // line (we return from it.) + cleanup: // We only reach here in case of error in the while loop above. 
+ if(cur_row != NULL) + delete cur_row; + for (size_t i = 0; i < data.size(); i++) + if(data[i] != NULL) + delete data[i]; + // and then go on to "bad" below, where we print error. + } +bad: + KALDI_ERR << "Failed to read matrix from stream. " << specific_error.str() + << " File position at start is " + << pos_at_start << ", currently " << is.tellg(); +} + + +// Constructor... note that this is not const-safe as it would +// be quite complicated to implement a "const SubMatrix" class that +// would not allow its contents to be changed. +template +SubMatrix::SubMatrix(const MatrixBase &M, + const MatrixIndexT ro, + const MatrixIndexT r, + const MatrixIndexT co, + const MatrixIndexT c) { + if (r == 0 || c == 0) { + // we support the empty sub-matrix as a special case. + KALDI_ASSERT(c == 0 && r == 0); + this->data_ = NULL; + this->num_cols_ = 0; + this->num_rows_ = 0; + this->stride_ = 0; + return; + } + KALDI_ASSERT(static_cast(ro) < + static_cast(M.num_rows_) && + static_cast(co) < + static_cast(M.num_cols_) && + static_cast(r) <= + static_cast(M.num_rows_ - ro) && + static_cast(c) <= + static_cast(M.num_cols_ - co)); + // point to the begining of window + MatrixBase::num_rows_ = r; + MatrixBase::num_cols_ = c; + MatrixBase::stride_ = M.Stride(); + MatrixBase::data_ = M.Data_workaround() + + static_cast(co) + + static_cast(ro) * static_cast(M.Stride()); +} + + +template +SubMatrix::SubMatrix(Real *data, + MatrixIndexT num_rows, + MatrixIndexT num_cols, + MatrixIndexT stride): + MatrixBase(data, num_cols, num_rows, stride) { // caution: reversed order! 
+ if (data == NULL) { + KALDI_ASSERT(num_rows * num_cols == 0); + this->num_rows_ = 0; + this->num_cols_ = 0; + this->stride_ = 0; + } else { + KALDI_ASSERT(this->stride_ >= this->num_cols_); + } +} + + +template +void MatrixBase::Add(const Real alpha) { + Real *data = data_; + MatrixIndexT stride = stride_; + for (MatrixIndexT r = 0; r < num_rows_; r++) + for (MatrixIndexT c = 0; c < num_cols_; c++) + data[c + stride*r] += alpha; +} + +template +void MatrixBase::AddToDiag(const Real alpha) { + Real *data = data_; + MatrixIndexT this_stride = stride_ + 1, + num_to_add = std::min(num_rows_, num_cols_); + for (MatrixIndexT r = 0; r < num_to_add; r++) + data[r * this_stride] += alpha; +} + + +template +Real MatrixBase::Cond() const { + KALDI_ASSERT(num_rows_ > 0&&num_cols_ > 0); + Vector singular_values(std::min(num_rows_, num_cols_)); + Svd(&singular_values); // Get singular values... + Real min = singular_values(0), max = singular_values(0); // both absolute values... + for (MatrixIndexT i = 1;i < singular_values.Dim();i++) { + min = std::min((Real)std::abs(singular_values(i)), min); max = std::max((Real)std::abs(singular_values(i)), max); + } + if (min > 0) return max/min; + else return std::numeric_limits::infinity(); +} + +template +Real MatrixBase::Trace(bool check_square) const { + KALDI_ASSERT(!check_square || num_rows_ == num_cols_); + Real ans = 0.0; + for (MatrixIndexT r = 0;r < std::min(num_rows_, num_cols_);r++) ans += data_ [r + stride_*r]; + return ans; +} + +template +Real MatrixBase::Max() const { + KALDI_ASSERT(num_rows_ > 0 && num_cols_ > 0); + Real ans= *data_; + for (MatrixIndexT r = 0; r < num_rows_; r++) + for (MatrixIndexT c = 0; c < num_cols_; c++) + if (data_[c + stride_*r] > ans) + ans = data_[c + stride_*r]; + return ans; +} + +template +Real MatrixBase::Min() const { + KALDI_ASSERT(num_rows_ > 0 && num_cols_ > 0); + Real ans= *data_; + for (MatrixIndexT r = 0; r < num_rows_; r++) + for (MatrixIndexT c = 0; c < num_cols_; c++) + if (data_[c 
+ stride_*r] < ans) + ans = data_[c + stride_*r]; + return ans; +} + + + +template +void MatrixBase::AddMatMatMat(Real alpha, + const MatrixBase &A, MatrixTransposeType transA, + const MatrixBase &B, MatrixTransposeType transB, + const MatrixBase &C, MatrixTransposeType transC, + Real beta) { + // Note on time taken with different orders of computation. Assume not transposed in this / + // discussion. Firstly, normalize expressions using A.NumCols == B.NumRows and B.NumCols == C.NumRows, prefer + // rows where there is a choice. + // time taken for (AB) is: A.NumRows*B.NumRows*C.Rows + // time taken for (AB)C is A.NumRows*C.NumRows*C.Cols + // so this order is A.NumRows*B.NumRows*C.NumRows + A.NumRows*C.NumRows*C.NumCols. + + // time taken for (BC) is: B.NumRows*C.NumRows*C.Cols + // time taken for A(BC) is: A.NumRows*B.NumRows*C.Cols + // so this order is B.NumRows*C.NumRows*C.NumCols + A.NumRows*B.NumRows*C.Cols + + MatrixIndexT ARows = A.num_rows_, ACols = A.num_cols_, BRows = B.num_rows_, BCols = B.num_cols_, + CRows = C.num_rows_, CCols = C.num_cols_; + if (transA == kTrans) std::swap(ARows, ACols); + if (transB == kTrans) std::swap(BRows, BCols); + if (transC == kTrans) std::swap(CRows, CCols); + + MatrixIndexT AB_C_time = ARows*BRows*CRows + ARows*CRows*CCols; + MatrixIndexT A_BC_time = BRows*CRows*CCols + ARows*BRows*CCols; + + if (AB_C_time < A_BC_time) { + Matrix AB(ARows, BCols); + AB.AddMatMat(1.0, A, transA, B, transB, 0.0); // AB = A * B. + (*this).AddMatMat(alpha, AB, kNoTrans, C, transC, beta); + } else { + Matrix BC(BRows, CCols); + BC.AddMatMat(1.0, B, transB, C, transC, 0.0); // BC = B * C. + (*this).AddMatMat(alpha, A, transA, BC, kNoTrans, beta); + } +} + + + + +template +void MatrixBase::DestructiveSvd(VectorBase *s, MatrixBase *U, MatrixBase *Vt) { + // Svd, *this = U*diag(s)*Vt. 
+ // With (*this).num_rows_ == m, (*this).num_cols_ == n, + // Support only skinny Svd with m>=n (NumRows>=NumCols), and zero sizes for U and Vt mean + // we do not want that output. We expect that s.Dim() == m, + // U is either 0 by 0 or m by n, and rv is either 0 by 0 or n by n. + // Throws exception on error. + + KALDI_ASSERT(num_rows_>=num_cols_ && "Svd requires that #rows by >= #cols."); // For compatibility with JAMA code. + KALDI_ASSERT(s->Dim() == num_cols_); // s should be the smaller dim. + KALDI_ASSERT(U == NULL || (U->num_rows_ == num_rows_&&U->num_cols_ == num_cols_)); + KALDI_ASSERT(Vt == NULL || (Vt->num_rows_ == num_cols_&&Vt->num_cols_ == num_cols_)); + + Real prescale = 1.0; + if ( std::abs((*this)(0, 0) ) < 1.0e-30) { // Very tiny value... can cause problems in Svd. + Real max_elem = LargestAbsElem(); + if (max_elem != 0) { + prescale = 1.0 / max_elem; + if (std::abs(prescale) == std::numeric_limits::infinity()) { prescale = 1.0e+40; } + (*this).Scale(prescale); + } + } + +#if !defined(HAVE_ATLAS) && !defined(USE_KALDI_SVD) + // "S" == skinny Svd (only one we support because of compatibility with Jama one which is only skinny), + // "N"== no eigenvectors wanted. + LapackGesvd(s, U, Vt); +#else + /* if (num_rows_ > 1 && num_cols_ > 1 && (*this)(0, 0) == (*this)(1, 1) + && Max() == Min() && (*this)(0, 0) != 0.0) { // special case that JamaSvd sometimes crashes on. + KALDI_WARN << "Jama SVD crashes on this type of matrix, perturbing it to prevent crash."; + for(int32 i = 0; i < NumRows(); i++) + (*this)(i, i) *= 1.00001; + }*/ + bool ans = JamaSvd(s, U, Vt); + if (Vt != NULL) Vt->Transpose(); // possibly to do: change this and also the transpose inside the JamaSvd routine. note, Vt is square. + if (!ans) { + KALDI_ERR << "Error doing Svd"; // This one will be caught. 
+ } +#endif + if (prescale != 1.0) s->Scale(1.0/prescale); +} + +template +void MatrixBase::Svd(VectorBase *s, MatrixBase *U, MatrixBase *Vt) const { + try { + if (num_rows_ >= num_cols_) { + Matrix tmp(*this); + tmp.DestructiveSvd(s, U, Vt); + } else { + Matrix tmp(*this, kTrans); // transpose of *this. + // rVt will have different dim so cannot transpose in-place --> use a temp matrix. + Matrix Vt_Trans(Vt ? Vt->num_cols_ : 0, Vt ? Vt->num_rows_ : 0); + // U will be transpose + tmp.DestructiveSvd(s, Vt ? &Vt_Trans : NULL, U); + if (U) U->Transpose(); + if (Vt) Vt->CopyFromMat(Vt_Trans, kTrans); // copy with transpose. + } + } catch (...) { + KALDI_ERR << "Error doing Svd (did not converge), first part of matrix is\n" + << SubMatrix(*this, 0, std::min((MatrixIndexT)10, num_rows_), + 0, std::min((MatrixIndexT)10, num_cols_)) + << ", min and max are: " << Min() << ", " << Max(); + } +} + +template +bool MatrixBase::IsSymmetric(Real cutoff) const { + MatrixIndexT R = num_rows_, C = num_cols_; + if (R != C) return false; + Real bad_sum = 0.0, good_sum = 0.0; + for (MatrixIndexT i = 0;i < R;i++) { + for (MatrixIndexT j = 0;j < i;j++) { + Real a = (*this)(i, j), b = (*this)(j, i), avg = 0.5*(a+b), diff = 0.5*(a-b); + good_sum += std::abs(avg); bad_sum += std::abs(diff); + } + good_sum += std::abs((*this)(i, i)); + } + if (bad_sum > cutoff*good_sum) return false; + return true; +} + +template +bool MatrixBase::IsDiagonal(Real cutoff) const{ + MatrixIndexT R = num_rows_, C = num_cols_; + Real bad_sum = 0.0, good_sum = 0.0; + for (MatrixIndexT i = 0;i < R;i++) { + for (MatrixIndexT j = 0;j < C;j++) { + if (i == j) good_sum += std::abs((*this)(i, j)); + else bad_sum += std::abs((*this)(i, j)); + } + } + return (!(bad_sum > good_sum * cutoff)); +} + +// This does nothing, it's designed to trigger Valgrind errors +// if any memory is uninitialized. 
+template +void MatrixBase::TestUninitialized() const { + MatrixIndexT R = num_rows_, C = num_cols_, positive = 0; + for (MatrixIndexT i = 0; i < R; i++) + for (MatrixIndexT j = 0; j < C; j++) + if ((*this)(i, j) > 0.0) positive++; + if (positive > R * C) + KALDI_ERR << "Error...."; +} + + +template +bool MatrixBase::IsUnit(Real cutoff) const { + MatrixIndexT R = num_rows_, C = num_cols_; + Real bad_max = 0.0; + for (MatrixIndexT i = 0; i < R;i++) + for (MatrixIndexT j = 0; j < C;j++) + bad_max = std::max(bad_max, static_cast(std::abs( (*this)(i, j) - (i == j?1.0:0.0)))); + return (bad_max <= cutoff); +} + +template +bool MatrixBase::IsZero(Real cutoff)const { + MatrixIndexT R = num_rows_, C = num_cols_; + Real bad_max = 0.0; + for (MatrixIndexT i = 0;i < R;i++) + for (MatrixIndexT j = 0;j < C;j++) + bad_max = std::max(bad_max, static_cast(std::abs( (*this)(i, j) ))); + return (bad_max <= cutoff); +} + +template +Real MatrixBase::FrobeniusNorm() const{ + return std::sqrt(TraceMatMat(*this, *this, kTrans)); +} + +template +bool MatrixBase::ApproxEqual(const MatrixBase &other, float tol) const { + if (num_rows_ != other.num_rows_ || num_cols_ != other.num_cols_) + KALDI_ERR << "ApproxEqual: size mismatch."; + Matrix tmp(*this); + tmp.AddMat(-1.0, other); + return (tmp.FrobeniusNorm() <= static_cast(tol) * + this->FrobeniusNorm()); +} + +template +bool MatrixBase::Equal(const MatrixBase &other) const { + if (num_rows_ != other.num_rows_ || num_cols_ != other.num_cols_) + KALDI_ERR << "Equal: size mismatch."; + for (MatrixIndexT i = 0; i < num_rows_; i++) + for (MatrixIndexT j = 0; j < num_cols_; j++) + if ( (*this)(i, j) != other(i, j)) + return false; + return true; +} + + +template +Real MatrixBase::LargestAbsElem() const{ + MatrixIndexT R = num_rows_, C = num_cols_; + Real largest = 0.0; + for (MatrixIndexT i = 0;i < R;i++) + for (MatrixIndexT j = 0;j < C;j++) + largest = std::max(largest, (Real)std::abs((*this)(i, j))); + return largest; +} + + +template +void 
MatrixBase::OrthogonalizeRows() { + KALDI_ASSERT(NumRows() <= NumCols()); + MatrixIndexT num_rows = num_rows_; + for (MatrixIndexT i = 0; i < num_rows; i++) { + int32 counter = 0; + while (1) { + Real start_prod = VecVec(this->Row(i), this->Row(i)); + if (start_prod - start_prod != 0.0 || start_prod == 0.0) { + KALDI_WARN << "Self-product of row " << i << " of matrix is " + << start_prod << ", randomizing."; + this->Row(i).SetRandn(); + counter++; + continue; // loop again. + } + for (MatrixIndexT j = 0; j < i; j++) { + Real prod = VecVec(this->Row(i), this->Row(j)); + this->Row(i).AddVec(-prod, this->Row(j)); + } + Real end_prod = VecVec(this->Row(i), this->Row(i)); + if (end_prod <= 0.01 * start_prod) { // We removed + // almost all of the vector during orthogonalization, + // so we have reason to doubt (for roundoff reasons) + // that it's still orthogonal to the other vectors. + // We need to orthogonalize again. + if (end_prod == 0.0) { // Row is exactly zero: + // generate random direction. + this->Row(i).SetRandn(); + } + counter++; + if (counter > 100) + KALDI_ERR << "Loop detected while orthogalizing matrix."; + } else { + this->Row(i).Scale(1.0 / std::sqrt(end_prod)); + break; + } + } + } +} + + +// Uses Svd to compute the eigenvalue decomposition of a symmetric positive semidefinite +// matrix: +// (*this) = rU * diag(rs) * rU^T, with rU an orthogonal matrix so rU^{-1} = rU^T. +// Does this by computing svd (*this) = U diag(rs) V^T ... answer is just U diag(rs) U^T. +// Throws exception if this failed to within supplied precision (typically because *this was not +// symmetric positive definite). + +template +void MatrixBase::SymPosSemiDefEig(VectorBase *rs, MatrixBase *rU, Real check_thresh) // e.g. 
check_thresh = 0.001 +{ + const MatrixIndexT D = num_rows_; + + KALDI_ASSERT(num_rows_ == num_cols_); + KALDI_ASSERT(IsSymmetric() && "SymPosSemiDefEig: expecting input to be symmetrical."); + KALDI_ASSERT(rU->num_rows_ == D && rU->num_cols_ == D && rs->Dim() == D); + + Matrix Vt(D, D); + Svd(rs, rU, &Vt); + + // First just zero any singular values if the column of U and V do not have +ve dot product-- + // this may mean we have small negative eigenvalues, and if we zero them the result will be closer to correct. + for (MatrixIndexT i = 0;i < D;i++) { + Real sum = 0.0; + for (MatrixIndexT j = 0;j < D;j++) sum += (*rU)(j, i) * Vt(i, j); + if (sum < 0.0) (*rs)(i) = 0.0; + } + + { + Matrix tmpU(*rU); Vector tmps(*rs); tmps.ApplyPow(0.5); + tmpU.MulColsVec(tmps); + SpMatrix tmpThis(D); + tmpThis.AddMat2(1.0, tmpU, kNoTrans, 0.0); + Matrix tmpThisFull(tmpThis); + float new_norm = tmpThisFull.FrobeniusNorm(); + float old_norm = (*this).FrobeniusNorm(); + tmpThisFull.AddMat(-1.0, (*this)); + + if (!(old_norm == 0 && new_norm == 0)) { + float diff_norm = tmpThisFull.FrobeniusNorm(); + if (std::abs(new_norm-old_norm) > old_norm*check_thresh || diff_norm > old_norm*check_thresh) { + KALDI_WARN << "SymPosSemiDefEig seems to have failed " << diff_norm << " !<< " + << check_thresh << "*" << old_norm << ", maybe matrix was not " + << "positive semi definite. Continuing anyway."; + } + } + } +} + + +template +Real MatrixBase::LogDet(Real *det_sign) const { + Real log_det; + Matrix tmp(*this); + tmp.Invert(&log_det, det_sign, false); // false== output not needed (saves some computation). 
+ return log_det; +} + +template +void MatrixBase::InvertDouble(Real *log_det, Real *det_sign, + bool inverse_needed) { + double log_det_tmp, det_sign_tmp; + Matrix dmat(*this); + dmat.Invert(&log_det_tmp, &det_sign_tmp, inverse_needed); + if (inverse_needed) (*this).CopyFromMat(dmat); + if (log_det) *log_det = log_det_tmp; + if (det_sign) *det_sign = det_sign_tmp; +} + +template +void MatrixBase::CopyFromMat(const CompressedMatrix &mat) { + mat.CopyToMat(this); +} + +template +Matrix::Matrix(const CompressedMatrix &M): MatrixBase() { + Resize(M.NumRows(), M.NumCols(), kUndefined); + M.CopyToMat(this); +} + + + +template +void MatrixBase::InvertElements() { + for (MatrixIndexT r = 0; r < num_rows_; r++) { + for (MatrixIndexT c = 0; c < num_cols_; c++) { + (*this)(r, c) = static_cast(1.0 / (*this)(r, c)); + } + } +} + +template +void MatrixBase::Transpose() { + KALDI_ASSERT(num_rows_ == num_cols_); + MatrixIndexT M = num_rows_; + for (MatrixIndexT i = 0;i < M;i++) + for (MatrixIndexT j = 0;j < i;j++) { + Real &a = (*this)(i, j), &b = (*this)(j, i); + std::swap(a, b); + } +} + + +template +void Matrix::Transpose() { + if (this->num_rows_ != this->num_cols_) { + Matrix tmp(*this, kTrans); + Resize(this->num_cols_, this->num_rows_); + this->CopyFromMat(tmp); + } else { + (static_cast&>(*this)).Transpose(); + } +} + +template +void MatrixBase::Heaviside(const MatrixBase &src) { + KALDI_ASSERT(SameDim(*this, src)); + MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; + Real *row_data = data_; + const Real *src_row_data = src.Data(); + for (MatrixIndexT row = 0; row < num_rows; + row++,row_data += stride_, src_row_data += src.stride_) { + for (MatrixIndexT col = 0; col < num_cols; col++) + row_data[col] = (src_row_data[col] > 0 ? 
1.0 : 0.0); + } +} + +template +void MatrixBase::Exp(const MatrixBase &src) { + KALDI_ASSERT(SameDim(*this, src)); + MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; + Real *row_data = data_; + const Real *src_row_data = src.Data(); + for (MatrixIndexT row = 0; row < num_rows; + row++,row_data += stride_, src_row_data += src.stride_) { + for (MatrixIndexT col = 0; col < num_cols; col++) + row_data[col] = kaldi::Exp(src_row_data[col]); + } +} + +template +void MatrixBase::Pow(const MatrixBase &src, Real power) { + KALDI_ASSERT(SameDim(*this, src)); + MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; + Real *row_data = data_; + const Real *src_row_data = src.Data(); + for (MatrixIndexT row = 0; row < num_rows; + row++,row_data += stride_, src_row_data += src.stride_) { + for (MatrixIndexT col = 0; col < num_cols; col++) { + row_data[col] = pow(src_row_data[col], power); + } + } +} + +template +void MatrixBase::PowAbs(const MatrixBase &src, Real power, bool include_sign) { + KALDI_ASSERT(SameDim(*this, src)); + MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; + Real *row_data = data_; + const Real *src_row_data = src.Data(); + for (MatrixIndexT row = 0; row < num_rows; + row++,row_data += stride_, src_row_data += src.stride_) { + for (MatrixIndexT col = 0; col < num_cols; col ++) { + if (include_sign == true && src_row_data[col] < 0) { + row_data[col] = -pow(std::abs(src_row_data[col]), power); + } else { + row_data[col] = pow(std::abs(src_row_data[col]), power); + } + } + } +} + +template +void MatrixBase::Floor(const MatrixBase &src, Real floor_val) { + KALDI_ASSERT(SameDim(*this, src)); + MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; + Real *row_data = data_; + const Real *src_row_data = src.Data(); + for (MatrixIndexT row = 0; row < num_rows; + row++,row_data += stride_, src_row_data += src.stride_) { + for (MatrixIndexT col = 0; col < num_cols; col++) + row_data[col] = (src_row_data[col] < floor_val ? 
floor_val : src_row_data[col]); + } +} + +template +void MatrixBase::Ceiling(const MatrixBase &src, Real ceiling_val) { + KALDI_ASSERT(SameDim(*this, src)); + MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; + Real *row_data = data_; + const Real *src_row_data = src.Data(); + for (MatrixIndexT row = 0; row < num_rows; + row++,row_data += stride_, src_row_data += src.stride_) { + for (MatrixIndexT col = 0; col < num_cols; col++) + row_data[col] = (src_row_data[col] > ceiling_val ? ceiling_val : src_row_data[col]); + } +} + +template +void MatrixBase::Log(const MatrixBase &src) { + KALDI_ASSERT(SameDim(*this, src)); + MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; + Real *row_data = data_; + const Real *src_row_data = src.Data(); + for (MatrixIndexT row = 0; row < num_rows; + row++,row_data += stride_, src_row_data += src.stride_) { + for (MatrixIndexT col = 0; col < num_cols; col++) + row_data[col] = kaldi::Log(src_row_data[col]); + } +} + +template +void MatrixBase::ExpSpecial(const MatrixBase &src) { + KALDI_ASSERT(SameDim(*this, src)); + MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; + Real *row_data = data_; + const Real *src_row_data = src.Data(); + for (MatrixIndexT row = 0; row < num_rows; + row++,row_data += stride_, src_row_data += src.stride_) { + for (MatrixIndexT col = 0; col < num_cols; col++) + row_data[col] = (src_row_data[col] < Real(0) ? 
kaldi::Exp(src_row_data[col]) : (src_row_data[col] + Real(1))); + } +} + +template +void MatrixBase::ExpLimited(const MatrixBase &src, Real lower_limit, Real upper_limit) { + KALDI_ASSERT(SameDim(*this, src)); + MatrixIndexT num_rows = num_rows_, num_cols = num_cols_; + Real *row_data = data_; + const Real *src_row_data = src.Data(); + for (MatrixIndexT row = 0; row < num_rows; + row++,row_data += stride_, src_row_data += src.stride_) { + for (MatrixIndexT col = 0; col < num_cols; col++) { + const Real x = src_row_data[col]; + if (!(x >= lower_limit)) + row_data[col] = kaldi::Exp(lower_limit); + else if (x > upper_limit) + row_data[col] = kaldi::Exp(upper_limit); + else + row_data[col] = kaldi::Exp(x); + } + } +} + +template +bool MatrixBase::Power(Real power) { + KALDI_ASSERT(num_rows_ > 0 && num_rows_ == num_cols_); + MatrixIndexT n = num_rows_; + Matrix P(n, n); + Vector re(n), im(n); + this->Eig(&P, &re, &im); + // Now attempt to take the complex eigenvalues to this power. + for (MatrixIndexT i = 0; i < n; i++) + if (!AttemptComplexPower(&(re(i)), &(im(i)), power)) + return false; // e.g. real and negative, or zero, eigenvalues. + + Matrix D(n, n); // D to the power. + CreateEigenvalueMatrix(re, im, &D); + + Matrix tmp(n, n); // P times D + tmp.AddMatMat(1.0, P, kNoTrans, D, kNoTrans, 0.0); // tmp := P*D + P.Invert(); + // next line is: *this = tmp * P^{-1} = P * D * P^{-1} + (*this).AddMatMat(1.0, tmp, kNoTrans, P, kNoTrans, 0.0); + return true; +} + +template +void Matrix::Swap(Matrix *other) { + std::swap(this->data_, other->data_); + std::swap(this->num_cols_, other->num_cols_); + std::swap(this->num_rows_, other->num_rows_); + std::swap(this->stride_, other->stride_); +} + +// Repeating this comment that appeared in the header: +// Eigenvalue Decomposition of a square NxN matrix into the form (*this) = P D +// P^{-1}. Be careful: the relationship of D to the eigenvalues we output is +// slightly complicated, due to the need for P to be real. 
In the symmetric +// case D is diagonal and real, but in +// the non-symmetric case there may be complex-conjugate pairs of eigenvalues. +// In this case, for the equation (*this) = P D P^{-1} to hold, D must actually +// be block diagonal, with 2x2 blocks corresponding to any such pairs. If a +// pair is lambda +- i*mu, D will have a corresponding 2x2 block +// [lambda, mu; -mu, lambda]. +// Note that if the input matrix (*this) is non-invertible, P may not be invertible +// so in this case instead of the equation (*this) = P D P^{-1} holding, we have +// instead (*this) P = P D. +// +// By making the pointer arguments non-NULL or NULL, the user can choose to take +// not to take the eigenvalues directly, and/or the matrix D which is block-diagonal +// with 2x2 blocks. +template +void MatrixBase::Eig(MatrixBase *P, + VectorBase *r, + VectorBase *i) const { + EigenvalueDecomposition eig(*this); + if (P) eig.GetV(P); + if (r) eig.GetRealEigenvalues(r); + if (i) eig.GetImagEigenvalues(i); +} + + +// Begin non-member function definitions. + +// /** +// * @brief Extension of the HTK header +// */ +// struct HtkHeaderExt +// { +// INT_32 mHeaderSize; +// INT_32 mVersion; +// INT_32 mSampSize; +// }; + +template +bool ReadHtk(std::istream &is, Matrix *M_ptr, HtkHeader *header_ptr) +{ + // check instantiated with double or float. + KALDI_ASSERT_IS_FLOATING_TYPE(Real); + Matrix &M = *M_ptr; + HtkHeader htk_hdr; + + // TODO(arnab): this fails if the HTK file has CRC cheksum or is compressed. + is.read((char*)&htk_hdr, sizeof(htk_hdr)); // we're being really POSIX here! + if (is.fail()) { + KALDI_WARN << "Could not read header from HTK feature file "; + return false; + } + + KALDI_SWAP4(htk_hdr.mNSamples); + KALDI_SWAP4(htk_hdr.mSamplePeriod); + KALDI_SWAP2(htk_hdr.mSampleSize); + KALDI_SWAP2(htk_hdr.mSampleKind); + + bool has_checksum = false; + { + // See HParm.h in HTK code for sources of these things. 
+ enum BaseParmKind{ + Waveform, Lpc, Lprefc, Lpcepstra, Lpdelcep, + Irefc, Mfcc, Fbank, Melspec, User, Discrete, Plp, Anon }; + + const int32 IsCompressed = 02000, HasChecksum = 010000, HasVq = 040000, + Problem = IsCompressed | HasVq; + int32 base_parm = htk_hdr.mSampleKind & (077); + has_checksum = (base_parm & HasChecksum) != 0; + htk_hdr.mSampleKind &= ~HasChecksum; // We don't support writing with + // checksum so turn it off. + if (htk_hdr.mSampleKind & Problem) + KALDI_ERR << "Code to read HTK features does not support compressed " + "features, or features with VQ."; + if (base_parm == Waveform || base_parm == Irefc || base_parm == Discrete) + KALDI_ERR << "Attempting to read HTK features from unsupported type " + "(e.g. waveform or discrete features."; + } + + KALDI_VLOG(3) << "HTK header: Num Samples: " << htk_hdr.mNSamples + << "; Sample period: " << htk_hdr.mSamplePeriod + << "; Sample size: " << htk_hdr.mSampleSize + << "; Sample kind: " << htk_hdr.mSampleKind; + + M.Resize(htk_hdr.mNSamples, htk_hdr.mSampleSize / sizeof(float)); + + MatrixIndexT i; + MatrixIndexT j; + if (sizeof(Real) == sizeof(float)) { + for (i = 0; i< M.NumRows(); i++) { + is.read((char*)M.RowData(i), sizeof(float)*M.NumCols()); + if (is.fail()) { + KALDI_WARN << "Could not read data from HTK feature file "; + return false; + } + if (MachineIsLittleEndian()) { + MatrixIndexT C = M.NumCols(); + for (j = 0; j < C; j++) { + KALDI_SWAP4((M(i, j))); // The HTK standard is big-endian! + } + } + } + } else { + float *pmem = new float[M.NumCols()]; + for (i = 0; i < M.NumRows(); i++) { + is.read((char*)pmem, sizeof(float)*M.NumCols()); + if (is.fail()) { + KALDI_WARN << "Could not read data from HTK feature file "; + delete [] pmem; + return false; + } + MatrixIndexT C = M.NumCols(); + for (j = 0; j < C; j++) { + if (MachineIsLittleEndian()) // HTK standard is big-endian! 
+ KALDI_SWAP4(pmem[j]); + M(i, j) = static_cast(pmem[j]); + } + } + delete [] pmem; + } + if (header_ptr) *header_ptr = htk_hdr; + if (has_checksum) { + int16 checksum; + is.read((char*)&checksum, sizeof(checksum)); + if (is.fail()) + KALDI_WARN << "Could not read checksum from HTK feature file "; + // We ignore the checksum. + } + return true; +} + + +template +bool ReadHtk(std::istream &is, Matrix *M, HtkHeader *header_ptr); + +template +bool ReadHtk(std::istream &is, Matrix *M, HtkHeader *header_ptr); + +template +bool WriteHtk(std::ostream &os, const MatrixBase &M, HtkHeader htk_hdr) // header may be derived from a previous call to ReadHtk. Must be in binary mode. +{ + KALDI_ASSERT(M.NumRows() == static_cast(htk_hdr.mNSamples)); + KALDI_ASSERT(M.NumCols() == static_cast(htk_hdr.mSampleSize) / + static_cast(sizeof(float))); + + KALDI_SWAP4(htk_hdr.mNSamples); + KALDI_SWAP4(htk_hdr.mSamplePeriod); + KALDI_SWAP2(htk_hdr.mSampleSize); + KALDI_SWAP2(htk_hdr.mSampleKind); + + os.write((char*)&htk_hdr, sizeof(htk_hdr)); + if (os.fail()) goto bad; + + MatrixIndexT i; + MatrixIndexT j; + if (sizeof(Real) == sizeof(float) && !MachineIsLittleEndian()) { + for (i = 0; i< M.NumRows(); i++) { // Unlikely to reach here ever! 
+ os.write((char*)M.RowData(i), sizeof(float)*M.NumCols()); + if (os.fail()) goto bad; + } + } else { + float *pmem = new float[M.NumCols()]; + + for (i = 0; i < M.NumRows(); i++) { + const Real *rowData = M.RowData(i); + for (j = 0;j < M.NumCols();j++) + pmem[j] = static_cast ( rowData[j] ); + if (MachineIsLittleEndian()) + for (j = 0;j < M.NumCols();j++) + KALDI_SWAP4(pmem[j]); + os.write((char*)pmem, sizeof(float)*M.NumCols()); + if (os.fail()) { + delete [] pmem; + goto bad; + } + } + delete [] pmem; + } + return true; +bad: + KALDI_WARN << "Could not write to HTK feature file "; + return false; +} + +template +bool WriteHtk(std::ostream &os, const MatrixBase &M, HtkHeader htk_hdr); + +template +bool WriteHtk(std::ostream &os, const MatrixBase &M, HtkHeader htk_hdr); + +template +bool WriteSphinx(std::ostream &os, const MatrixBase &M) +{ + // CMUSphinx mfc file header contains count of the floats, followed + // by the data in float little endian format. + + int size = M.NumRows() * M.NumCols(); + os.write((char*)&size, sizeof(int)); + if (os.fail()) goto bad; + + MatrixIndexT i; + MatrixIndexT j; + if (sizeof(Real) == sizeof(float) && MachineIsLittleEndian()) { + for (i = 0; i< M.NumRows(); i++) { // Unlikely to reach here ever! 
+ os.write((char*)M.RowData(i), sizeof(float)*M.NumCols()); + if (os.fail()) goto bad; + } + } else { + float *pmem = new float[M.NumCols()]; + + for (i = 0; i < M.NumRows(); i++) { + const Real *rowData = M.RowData(i); + for (j = 0;j < M.NumCols();j++) + pmem[j] = static_cast ( rowData[j] ); + if (!MachineIsLittleEndian()) + for (j = 0;j < M.NumCols();j++) + KALDI_SWAP4(pmem[j]); + os.write((char*)pmem, sizeof(float)*M.NumCols()); + if (os.fail()) { + delete [] pmem; + goto bad; + } + } + delete [] pmem; + } + return true; +bad: + KALDI_WARN << "Could not write to Sphinx feature file"; + return false; +} + +template +bool WriteSphinx(std::ostream &os, const MatrixBase &M); + +template +bool WriteSphinx(std::ostream &os, const MatrixBase &M); + +template +Real TraceMatMatMat(const MatrixBase &A, MatrixTransposeType transA, + const MatrixBase &B, MatrixTransposeType transB, + const MatrixBase &C, MatrixTransposeType transC) { + MatrixIndexT ARows = A.NumRows(), ACols = A.NumCols(), BRows = B.NumRows(), BCols = B.NumCols(), + CRows = C.NumRows(), CCols = C.NumCols(); + if (transA == kTrans) std::swap(ARows, ACols); + if (transB == kTrans) std::swap(BRows, BCols); + if (transC == kTrans) std::swap(CRows, CCols); + KALDI_ASSERT( CCols == ARows && ACols == BRows && BCols == CRows && "TraceMatMatMat: args have mismatched dimensions."); + if (ARows*BCols < std::min(BRows*CCols, CRows*ACols)) { + Matrix AB(ARows, BCols); + AB.AddMatMat(1.0, A, transA, B, transB, 0.0); // AB = A * B. + return TraceMatMat(AB, C, transC); + } else if ( BRows*CCols < CRows*ACols) { + Matrix BC(BRows, CCols); + BC.AddMatMat(1.0, B, transB, C, transC, 0.0); // BC = B * C. 
+ return TraceMatMat(BC, A, transA); + } else { + Matrix CA(CRows, ACols); + CA.AddMatMat(1.0, C, transC, A, transA, 0.0); // CA = C * A + return TraceMatMat(CA, B, transB); + } +} + +template +float TraceMatMatMat(const MatrixBase &A, MatrixTransposeType transA, + const MatrixBase &B, MatrixTransposeType transB, + const MatrixBase &C, MatrixTransposeType transC); + +template +double TraceMatMatMat(const MatrixBase &A, MatrixTransposeType transA, + const MatrixBase &B, MatrixTransposeType transB, + const MatrixBase &C, MatrixTransposeType transC); + + +template +Real TraceMatMatMatMat(const MatrixBase &A, MatrixTransposeType transA, + const MatrixBase &B, MatrixTransposeType transB, + const MatrixBase &C, MatrixTransposeType transC, + const MatrixBase &D, MatrixTransposeType transD) { + MatrixIndexT ARows = A.NumRows(), ACols = A.NumCols(), BRows = B.NumRows(), BCols = B.NumCols(), + CRows = C.NumRows(), CCols = C.NumCols(), DRows = D.NumRows(), DCols = D.NumCols(); + if (transA == kTrans) std::swap(ARows, ACols); + if (transB == kTrans) std::swap(BRows, BCols); + if (transC == kTrans) std::swap(CRows, CCols); + if (transD == kTrans) std::swap(DRows, DCols); + KALDI_ASSERT( DCols == ARows && ACols == BRows && BCols == CRows && CCols == DRows && "TraceMatMatMat: args have mismatched dimensions."); + if (ARows*BCols < std::min(BRows*CCols, std::min(CRows*DCols, DRows*ACols))) { + Matrix AB(ARows, BCols); + AB.AddMatMat(1.0, A, transA, B, transB, 0.0); // AB = A * B. + return TraceMatMatMat(AB, kNoTrans, C, transC, D, transD); + } else if ((BRows*CCols) < std::min(CRows*DCols, DRows*ACols)) { + Matrix BC(BRows, CCols); + BC.AddMatMat(1.0, B, transB, C, transC, 0.0); // BC = B * C. 
+ return TraceMatMatMat(BC, kNoTrans, D, transD, A, transA); + } else if (CRows*DCols < DRows*ACols) { + Matrix CD(CRows, DCols); + CD.AddMatMat(1.0, C, transC, D, transD, 0.0); // CD = C * D + return TraceMatMatMat(CD, kNoTrans, A, transA, B, transB); + } else { + Matrix DA(DRows, ACols); + DA.AddMatMat(1.0, D, transD, A, transA, 0.0); // DA = D * A + return TraceMatMatMat(DA, kNoTrans, B, transB, C, transC); + } +} + +template +float TraceMatMatMatMat(const MatrixBase &A, MatrixTransposeType transA, + const MatrixBase &B, MatrixTransposeType transB, + const MatrixBase &C, MatrixTransposeType transC, + const MatrixBase &D, MatrixTransposeType transD); + +template +double TraceMatMatMatMat(const MatrixBase &A, MatrixTransposeType transA, + const MatrixBase &B, MatrixTransposeType transB, + const MatrixBase &C, MatrixTransposeType transC, + const MatrixBase &D, MatrixTransposeType transD); + +template void SortSvd(VectorBase *s, MatrixBase *U, + MatrixBase *Vt, bool sort_on_absolute_value) { + /// Makes sure the Svd is sorted (from greatest to least absolute value). + MatrixIndexT num_singval = s->Dim(); + KALDI_ASSERT(U == NULL || U->NumCols() == num_singval); + KALDI_ASSERT(Vt == NULL || Vt->NumRows() == num_singval); + + std::vector > vec(num_singval); + // negative because we want revese order. + for (MatrixIndexT d = 0; d < num_singval; d++) { + Real val = (*s)(d), + sort_val = -(sort_on_absolute_value ? 
std::abs(val) : val); + vec[d] = std::pair(sort_val, d); + } + std::sort(vec.begin(), vec.end()); + Vector s_copy(*s); + for (MatrixIndexT d = 0; d < num_singval; d++) + (*s)(d) = s_copy(vec[d].second); + if (U != NULL) { + Matrix Utmp(*U); + MatrixIndexT dim = Utmp.NumRows(); + for (MatrixIndexT d = 0; d < num_singval; d++) { + MatrixIndexT oldidx = vec[d].second; + for (MatrixIndexT e = 0; e < dim; e++) + (*U)(e, d) = Utmp(e, oldidx); + } + } + if (Vt != NULL) { + Matrix Vttmp(*Vt); + for (MatrixIndexT d = 0; d < num_singval; d++) + (*Vt).Row(d).CopyFromVec(Vttmp.Row(vec[d].second)); + } +} + +template +void SortSvd(VectorBase *s, MatrixBase *U, + MatrixBase *Vt, bool); + +template +void SortSvd(VectorBase *s, MatrixBase *U, + MatrixBase *Vt, bool); + +template +void CreateEigenvalueMatrix(const VectorBase &re, const VectorBase &im, + MatrixBase *D) { + MatrixIndexT n = re.Dim(); + KALDI_ASSERT(im.Dim() == n && D->NumRows() == n && D->NumCols() == n); + + MatrixIndexT j = 0; + D->SetZero(); + while (j < n) { + if (im(j) == 0) { // Real eigenvalue + (*D)(j, j) = re(j); + j++; + } else { // First of a complex pair + KALDI_ASSERT(j+1 < n && ApproxEqual(im(j+1), -im(j)) + && ApproxEqual(re(j+1), re(j))); + /// if (im(j) < 0.0) KALDI_WARN << "Negative first im part of pair"; // TEMP + Real lambda = re(j), mu = im(j); + // create 2x2 block [lambda, mu; -mu, lambda] + (*D)(j, j) = lambda; + (*D)(j, j+1) = mu; + (*D)(j+1, j) = -mu; + (*D)(j+1, j+1) = lambda; + j += 2; + } + } +} + +template +void CreateEigenvalueMatrix(const VectorBase &re, const VectorBase &im, + MatrixBase *D); +template +void CreateEigenvalueMatrix(const VectorBase &re, const VectorBase &im, + MatrixBase *D); + + + +template +bool AttemptComplexPower(Real *x_re, Real *x_im, Real power) { + // Used in Matrix::Power(). + // Attempts to take the complex value x to the power "power", + // assuming that power is fractional (i.e. we don't treat integers as a + // special case). 
Returns false if this is not possible, either + // because x is negative and real (hence there is no obvious answer + // that is "closest to 1", and anyway this case does not make sense + // in the Matrix::Power() routine); + // or because power is negative, and x is zero. + + // First solve for r and theta in + // x_re = r*cos(theta), x_im = r*sin(theta) + if (*x_re < 0.0 && *x_im == 0.0) return false; // can't do + // it for negative real values. + Real r = std::sqrt((*x_re * *x_re) + (*x_im * *x_im)); // r == radius. + if (power < 0.0 && r == 0.0) return false; + Real theta = std::atan2(*x_im, *x_re); + // Take the power. + r = std::pow(r, power); + theta *= power; + *x_re = r * std::cos(theta); + *x_im = r * std::sin(theta); + return true; +} + +template +bool AttemptComplexPower(float *x_re, float *x_im, float power); +template +bool AttemptComplexPower(double *x_re, double *x_im, double power); + + + +template +Real TraceMatMat(const MatrixBase &A, + const MatrixBase &B, + MatrixTransposeType trans) { // tr(A B), equivalent to sum of each element of A times same element in B' + MatrixIndexT aStride = A.stride_, bStride = B.stride_; + if (trans == kNoTrans) { + KALDI_ASSERT(A.NumRows() == B.NumCols() && A.NumCols() == B.NumRows()); + Real ans = 0.0; + Real *adata = A.data_, *bdata = B.data_; + MatrixIndexT arows = A.NumRows(), acols = A.NumCols(); + for (MatrixIndexT row = 0;row < arows;row++, adata+=aStride, bdata++) + ans += cblas_Xdot(acols, adata, 1, bdata, bStride); + return ans; + } else { + KALDI_ASSERT(A.NumRows() == B.NumRows() && A.NumCols() == B.NumCols()); + Real ans = 0.0; + Real *adata = A.data_, *bdata = B.data_; + MatrixIndexT arows = A.NumRows(), acols = A.NumCols(); + for (MatrixIndexT row = 0;row < arows;row++, adata+=aStride, bdata+=bStride) + ans += cblas_Xdot(acols, adata, 1, bdata, 1); + return ans; + } +} + + +// Instantiate the template above for float and double. 
+template +float TraceMatMat(const MatrixBase &A, + const MatrixBase &B, + MatrixTransposeType trans); +template +double TraceMatMat(const MatrixBase &A, + const MatrixBase &B, + MatrixTransposeType trans); + + +template +Real MatrixBase::LogSumExp(Real prune) const { + Real sum; + if (sizeof(sum) == 8) sum = kLogZeroDouble; + else sum = kLogZeroFloat; + Real max_elem = Max(), cutoff; + if (sizeof(Real) == 4) cutoff = max_elem + kMinLogDiffFloat; + else cutoff = max_elem + kMinLogDiffDouble; + if (prune > 0.0 && max_elem - prune > cutoff) // explicit pruning... + cutoff = max_elem - prune; + + double sum_relto_max_elem = 0.0; + + for (MatrixIndexT i = 0; i < num_rows_; i++) { + for (MatrixIndexT j = 0; j < num_cols_; j++) { + BaseFloat f = (*this)(i, j); + if (f >= cutoff) + sum_relto_max_elem += kaldi::Exp(f - max_elem); + } + } + return max_elem + kaldi::Log(sum_relto_max_elem); +} + +template +Real MatrixBase::ApplySoftMax() { + Real max = this->Max(), sum = 0.0; + // the 'max' helps to get in good numeric range. 
+ for (MatrixIndexT i = 0; i < num_rows_; i++) + for (MatrixIndexT j = 0; j < num_cols_; j++) + sum += ((*this)(i, j) = kaldi::Exp((*this)(i, j) - max)); + this->Scale(1.0 / sum); + return max + kaldi::Log(sum); +} + +template +void MatrixBase::Tanh(const MatrixBase &src) { + KALDI_ASSERT(SameDim(*this, src)); + + if (num_cols_ == stride_ && src.num_cols_ == src.stride_) { + SubVector src_vec(src.data_, num_rows_ * num_cols_), + dst_vec(this->data_, num_rows_ * num_cols_); + dst_vec.Tanh(src_vec); + } else { + for (MatrixIndexT r = 0; r < num_rows_; r++) { + SubVector src_vec(src, r), dest_vec(*this, r); + dest_vec.Tanh(src_vec); + } + } +} + +template +void MatrixBase::SoftHinge(const MatrixBase &src) { + KALDI_ASSERT(SameDim(*this, src)); + int32 num_rows = num_rows_, num_cols = num_cols_; + for (MatrixIndexT r = 0; r < num_rows; r++) { + Real *row_data = this->RowData(r); + const Real *src_row_data = src.RowData(r); + for (MatrixIndexT c = 0; c < num_cols; c++) { + Real x = src_row_data[c], y; + if (x > 10.0) y = x; // avoid exponentiating large numbers; function + // approaches y=x. 
+ else y = Log1p(kaldi::Exp(x)); // these defined in kaldi-math.h + row_data[c] = y; + } + } +} + +template +void MatrixBase::GroupPnorm(const MatrixBase &src, Real power) { + KALDI_ASSERT(src.NumCols() % this->NumCols() == 0 && + src.NumRows() == this->NumRows()); + int group_size = src.NumCols() / this->NumCols(), + num_rows = this->NumRows(), num_cols = this->NumCols(); + for (MatrixIndexT i = 0; i < num_rows; i++) + for (MatrixIndexT j = 0; j < num_cols; j++) + (*this)(i, j) = src.Row(i).Range(j * group_size, group_size).Norm(power); +} + +template +void MatrixBase::GroupMax(const MatrixBase &src) { + KALDI_ASSERT(src.NumCols() % this->NumCols() == 0 && + src.NumRows() == this->NumRows()); + int group_size = src.NumCols() / this->NumCols(), + num_rows = this->NumRows(), num_cols = this->NumCols(); + for (MatrixIndexT i = 0; i < num_rows; i++) { + const Real *src_row_data = src.RowData(i); + for (MatrixIndexT j = 0; j < num_cols; j++) { + Real max_val = -1e20; + for (MatrixIndexT k = 0; k < group_size; k++) { + Real src_data = src_row_data[j * group_size + k]; + if (src_data > max_val) + max_val = src_data; + } + (*this)(i, j) = max_val; + } + } +} + +template +void MatrixBase::CopyCols(const MatrixBase &src, + const MatrixIndexT *indices) { + KALDI_ASSERT(NumRows() == src.NumRows()); + MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, + this_stride = stride_, src_stride = src.stride_; + Real *this_data = this->data_; + const Real *src_data = src.data_; +#ifdef KALDI_PARANOID + MatrixIndexT src_cols = src.NumCols(); + for (MatrixIndexT i = 0; i < num_cols; i++) + KALDI_ASSERT(indices[i] >= -1 && indices[i] < src_cols); +#endif + + // For the sake of memory locality we do this row by row, rather + // than doing it column-wise using cublas_Xcopy + for (MatrixIndexT r = 0; r < num_rows; r++, this_data += this_stride, src_data += src_stride) { + const MatrixIndexT *index_ptr = &(indices[0]); + for (MatrixIndexT c = 0; c < num_cols; c++, index_ptr++) { + if 
(*index_ptr < 0) this_data[c] = 0; + else this_data[c] = src_data[*index_ptr]; + } + } +} + + +template +void MatrixBase::AddCols(const MatrixBase &src, + const MatrixIndexT *indices) { + KALDI_ASSERT(NumRows() == src.NumRows()); + MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, + this_stride = stride_, src_stride = src.stride_; + Real *this_data = this->data_; + const Real *src_data = src.data_; +#ifdef KALDI_PARANOID + MatrixIndexT src_cols = src.NumCols(); + for (MatrixIndexT i = 0; i < num_cols; i++) + KALDI_ASSERT(indices[i] >= -1 && indices[i] < src_cols); +#endif + + // For the sake of memory locality we do this row by row, rather + // than doing it column-wise using cublas_Xcopy + for (MatrixIndexT r = 0; r < num_rows; r++, this_data += this_stride, src_data += src_stride) { + const MatrixIndexT *index_ptr = &(indices[0]); + for (MatrixIndexT c = 0; c < num_cols; c++, index_ptr++) { + if (*index_ptr >= 0) + this_data[c] += src_data[*index_ptr]; + } + } +} + +template +void MatrixBase::CopyRows(const MatrixBase &src, + const MatrixIndexT *indices) { + KALDI_ASSERT(NumCols() == src.NumCols()); + MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, + this_stride = stride_; + Real *this_data = this->data_; + + for (MatrixIndexT r = 0; r < num_rows; r++, this_data += this_stride) { + MatrixIndexT index = indices[r]; + if (index < 0) memset(this_data, 0, sizeof(Real) * num_cols_); + else cblas_Xcopy(num_cols, src.RowData(index), 1, this_data, 1); + } +} + +template +void MatrixBase::CopyRows(const Real *const *src) { + MatrixIndexT num_rows = num_rows_, + num_cols = num_cols_, this_stride = stride_; + Real *this_data = this->data_; + + for (MatrixIndexT r = 0; r < num_rows; r++, this_data += this_stride) { + const Real *const src_data = src[r]; + if (src_data == NULL) memset(this_data, 0, sizeof(Real) * num_cols); + else cblas_Xcopy(num_cols, src_data, 1, this_data, 1); + } +} + +template +void MatrixBase::CopyToRows(Real *const *dst) const { + 
MatrixIndexT num_rows = num_rows_, + num_cols = num_cols_, this_stride = stride_; + const Real *this_data = this->data_; + + for (MatrixIndexT r = 0; r < num_rows; r++, this_data += this_stride) { + Real *const dst_data = dst[r]; + if (dst_data != NULL) + cblas_Xcopy(num_cols, this_data, 1, dst_data, 1); + } +} + +template +void MatrixBase::AddRows(Real alpha, + const MatrixBase &src, + const MatrixIndexT *indexes) { + KALDI_ASSERT(NumCols() == src.NumCols()); + MatrixIndexT num_rows = num_rows_, + num_cols = num_cols_, this_stride = stride_; + Real *this_data = this->data_; + + for (MatrixIndexT r = 0; r < num_rows; r++, this_data += this_stride) { + MatrixIndexT index = indexes[r]; + KALDI_ASSERT(index >= -1 && index < src.NumRows()); + if (index != -1) + cblas_Xaxpy(num_cols, alpha, src.RowData(index), 1, this_data, 1); + } +} + +template +void MatrixBase::AddRows(Real alpha, const Real *const *src) { + MatrixIndexT num_rows = num_rows_, + num_cols = num_cols_, this_stride = stride_; + Real *this_data = this->data_; + + for (MatrixIndexT r = 0; r < num_rows; r++, this_data += this_stride) { + const Real *const src_data = src[r]; + if (src_data != NULL) + cblas_Xaxpy(num_cols, alpha, src_data, 1, this_data, 1); + } +} + +template +void MatrixBase::AddToRows(Real alpha, + const MatrixIndexT *indexes, + MatrixBase *dst) const { + KALDI_ASSERT(NumCols() == dst->NumCols()); + MatrixIndexT num_rows = num_rows_, + num_cols = num_cols_, this_stride = stride_; + Real *this_data = this->data_; + + for (MatrixIndexT r = 0; r < num_rows; r++, this_data += this_stride) { + MatrixIndexT index = indexes[r]; + KALDI_ASSERT(index >= -1 && index < dst->NumRows()); + if (index != -1) + cblas_Xaxpy(num_cols, alpha, this_data, 1, dst->RowData(index), 1); + } +} + +template +void MatrixBase::AddToRows(Real alpha, Real *const *dst) const { + MatrixIndexT num_rows = num_rows_, + num_cols = num_cols_, this_stride = stride_; + const Real *this_data = this->data_; + + for (MatrixIndexT r 
= 0; r < num_rows; r++, this_data += this_stride) { + Real *const dst_data = dst[r]; + if (dst_data != NULL) + cblas_Xaxpy(num_cols, alpha, this_data, 1, dst_data, 1); + } +} + +template +void MatrixBase::Sigmoid(const MatrixBase &src) { + KALDI_ASSERT(SameDim(*this, src)); + + if (num_cols_ == stride_ && src.num_cols_ == src.stride_) { + SubVector src_vec(src.data_, num_rows_ * num_cols_), + dst_vec(this->data_, num_rows_ * num_cols_); + dst_vec.Sigmoid(src_vec); + } else { + for (MatrixIndexT r = 0; r < num_rows_; r++) { + SubVector src_vec(src, r), dest_vec(*this, r); + dest_vec.Sigmoid(src_vec); + } + } +} + +template +void MatrixBase::DiffSigmoid(const MatrixBase &value, + const MatrixBase &diff) { + KALDI_ASSERT(SameDim(*this, value) && SameDim(*this, diff)); + MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, + stride = stride_, value_stride = value.stride_, diff_stride = diff.stride_; + Real *data = data_; + const Real *value_data = value.data_, *diff_data = diff.data_; + for (MatrixIndexT r = 0; r < num_rows; r++) { + for (MatrixIndexT c = 0; c < num_cols; c++) + data[c] = diff_data[c] * value_data[c] * (1.0 - value_data[c]); + data += stride; + value_data += value_stride; + diff_data += diff_stride; + } +} + +template +void MatrixBase::DiffTanh(const MatrixBase &value, + const MatrixBase &diff) { + KALDI_ASSERT(SameDim(*this, value) && SameDim(*this, diff)); + MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, + stride = stride_, value_stride = value.stride_, diff_stride = diff.stride_; + Real *data = data_; + const Real *value_data = value.data_, *diff_data = diff.data_; + for (MatrixIndexT r = 0; r < num_rows; r++) { + for (MatrixIndexT c = 0; c < num_cols; c++) + data[c] = diff_data[c] * (1.0 - (value_data[c] * value_data[c])); + data += stride; + value_data += value_stride; + diff_data += diff_stride; + } +} + + +template +template +void MatrixBase::AddVecToRows(const Real alpha, const VectorBase &v) { + const MatrixIndexT num_rows = 
num_rows_, num_cols = num_cols_, + stride = stride_; + KALDI_ASSERT(v.Dim() == num_cols); + if(num_cols <= 64) { + Real *data = data_; + const OtherReal *vdata = v.Data(); + for (MatrixIndexT i = 0; i < num_rows; i++, data += stride) { + for (MatrixIndexT j = 0; j < num_cols; j++) + data[j] += alpha * vdata[j]; + } + + } else { + Vector ones(num_rows); + ones.Set(1.0); + this->AddVecVec(alpha, ones, v); + } +} + +template void MatrixBase::AddVecToRows(const float alpha, + const VectorBase &v); +template void MatrixBase::AddVecToRows(const float alpha, + const VectorBase &v); +template void MatrixBase::AddVecToRows(const double alpha, + const VectorBase &v); +template void MatrixBase::AddVecToRows(const double alpha, + const VectorBase &v); + + +template +template +void MatrixBase::AddVecToCols(const Real alpha, const VectorBase &v) { + const MatrixIndexT num_rows = num_rows_, num_cols = num_cols_, + stride = stride_; + KALDI_ASSERT(v.Dim() == num_rows); + + if (num_rows <= 64) { + Real *data = data_; + const OtherReal *vdata = v.Data(); + for (MatrixIndexT i = 0; i < num_rows; i++, data += stride) { + Real to_add = alpha * vdata[i]; + for (MatrixIndexT j = 0; j < num_cols; j++) + data[j] += to_add; + } + + } else { + Vector ones(num_cols); + ones.Set(1.0); + this->AddVecVec(alpha, v, ones); + } +} + +template void MatrixBase::AddVecToCols(const float alpha, + const VectorBase &v); +template void MatrixBase::AddVecToCols(const float alpha, + const VectorBase &v); +template void MatrixBase::AddVecToCols(const double alpha, + const VectorBase &v); +template void MatrixBase::AddVecToCols(const double alpha, + const VectorBase &v); + +//Explicit instantiation of the classes +//Apparently, it seems to be necessary that the instantiation +//happens at the end of the file. Otherwise, not all the member +//functions will get instantiated. 
+ +template class Matrix; +template class Matrix; +template class MatrixBase; +template class MatrixBase; +template class SubMatrix; +template class SubMatrix; + +} // namespace kaldi diff --git a/speechx/speechx/kaldi/matrix/kaldi-matrix.h b/speechx/speechx/kaldi/matrix/kaldi-matrix.h new file mode 100644 index 00000000..4387538c --- /dev/null +++ b/speechx/speechx/kaldi/matrix/kaldi-matrix.h @@ -0,0 +1,1122 @@ +// matrix/kaldi-matrix.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Lukas Burget; +// Saarland University; Petr Schwarz; Yanmin Qian; +// Karel Vesely; Go Vivace Inc.; Haihua Xu +// 2017 Shiyin Kang +// 2019 Yiwen Shao + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_KALDI_MATRIX_H_ +#define KALDI_MATRIX_KALDI_MATRIX_H_ 1 + +#include + +#include "matrix/matrix-common.h" + +namespace kaldi { + +/// @{ \addtogroup matrix_funcs_scalar + +/// We need to declare this here as it will be a friend function. +/// tr(A B), or tr(A B^T). +template +Real TraceMatMat(const MatrixBase &A, const MatrixBase &B, + MatrixTransposeType trans = kNoTrans); +/// @} + +/// \addtogroup matrix_group +/// @{ + +/// Base class which provides matrix operations not involving resizing +/// or allocation. Classes Matrix and SubMatrix inherit from it and take care +/// of allocation and resizing. 
+template +class MatrixBase { + public: + // so this child can access protected members of other instances. + friend class Matrix; + // friend declarations for CUDA matrices (see ../cudamatrix/) + friend class CuMatrixBase; + friend class CuMatrix; + friend class CuSubMatrix; + friend class CuPackedMatrix; + friend class PackedMatrix; + friend class SparseMatrix; + friend class SparseMatrix; + friend class SparseMatrix; + + /// Returns number of rows (or zero for empty matrix). + inline MatrixIndexT NumRows() const { return num_rows_; } + + /// Returns number of columns (or zero for empty matrix). + inline MatrixIndexT NumCols() const { return num_cols_; } + + /// Stride (distance in memory between each row). Will be >= NumCols. + inline MatrixIndexT Stride() const { return stride_; } + + /// Returns size in bytes of the data held by the matrix. + size_t SizeInBytes() const { + return static_cast(num_rows_) * static_cast(stride_) * + sizeof(Real); + } + + /// Gives pointer to raw data (const). + inline const Real* Data() const { + return data_; + } + + /// Gives pointer to raw data (non-const). + inline Real* Data() { return data_; } + + /// Returns pointer to data for one row (non-const) + inline Real* RowData(MatrixIndexT i) { + KALDI_ASSERT(static_cast(i) < + static_cast(num_rows_)); + return data_ + i * stride_; + } + + /// Returns pointer to data for one row (const) + inline const Real* RowData(MatrixIndexT i) const { + KALDI_ASSERT(static_cast(i) < + static_cast(num_rows_)); + return data_ + i * stride_; + } + + /// Indexing operator, non-const + /// (only checks sizes if compiled with -DKALDI_PARANOID) + inline Real& operator() (MatrixIndexT r, MatrixIndexT c) { + KALDI_PARANOID_ASSERT(static_cast(r) < + static_cast(num_rows_) && + static_cast(c) < + static_cast(num_cols_)); + return *(data_ + r * stride_ + c); + } + /// Indexing operator, provided for ease of debugging (gdb doesn't work + /// with parenthesis operator). 
+ Real &Index (MatrixIndexT r, MatrixIndexT c) { return (*this)(r, c); } + + /// Indexing operator, const + /// (only checks sizes if compiled with -DKALDI_PARANOID) + inline const Real operator() (MatrixIndexT r, MatrixIndexT c) const { + KALDI_PARANOID_ASSERT(static_cast(r) < + static_cast(num_rows_) && + static_cast(c) < + static_cast(num_cols_)); + return *(data_ + r * stride_ + c); + } + + /* Basic setting-to-special values functions. */ + + /// Sets matrix to zero. + void SetZero(); + /// Sets all elements to a specific value. + void Set(Real); + /// Sets to zero, except ones along diagonal [for non-square matrices too] + void SetUnit(); + /// Sets to random values of a normal distribution + void SetRandn(); + /// Sets to numbers uniformly distributed on (0, 1) + void SetRandUniform(); + + /* Copying functions. These do not resize the matrix! */ + + + /// Copy given matrix. (no resize is done). + template + void CopyFromMat(const MatrixBase & M, + MatrixTransposeType trans = kNoTrans); + + /// Copy from compressed matrix. + void CopyFromMat(const CompressedMatrix &M); + + /// Copy given spmatrix. (no resize is done). + template + void CopyFromSp(const SpMatrix &M); + + /// Copy given tpmatrix. (no resize is done). + template + void CopyFromTp(const TpMatrix &M, + MatrixTransposeType trans = kNoTrans); + + /// Copy from CUDA matrix. Implemented in ../cudamatrix/cu-matrix.h + template + void CopyFromMat(const CuMatrixBase &M, + MatrixTransposeType trans = kNoTrans); + + /// This function has two modes of operation. If v.Dim() == NumRows() * + /// NumCols(), then treats the vector as a row-by-row concatenation of a + /// matrix and copies to *this. + /// if v.Dim() == NumCols(), it sets each row of *this to a copy of v. 
+ void CopyRowsFromVec(const VectorBase &v); + + /// This version of CopyRowsFromVec is implemented in ../cudamatrix/cu-vector.cc + void CopyRowsFromVec(const CuVectorBase &v); + + template + void CopyRowsFromVec(const VectorBase &v); + + /// Copies vector into matrix, column-by-column. + /// Note that rv.Dim() must either equal NumRows()*NumCols() or NumRows(); + /// this has two modes of operation. + void CopyColsFromVec(const VectorBase &v); + + /// Copy vector into specific column of matrix. + void CopyColFromVec(const VectorBase &v, const MatrixIndexT col); + /// Copy vector into specific row of matrix. + void CopyRowFromVec(const VectorBase &v, const MatrixIndexT row); + /// Copy vector into diagonal of matrix. + void CopyDiagFromVec(const VectorBase &v); + + /* Accessing of sub-parts of the matrix. */ + + /// Return specific row of matrix [const]. + inline const SubVector Row(MatrixIndexT i) const { + KALDI_ASSERT(static_cast(i) < + static_cast(num_rows_)); + return SubVector(data_ + (i * stride_), NumCols()); + } + + /// Return specific row of matrix. + inline SubVector Row(MatrixIndexT i) { + KALDI_ASSERT(static_cast(i) < + static_cast(num_rows_)); + return SubVector(data_ + (i * stride_), NumCols()); + } + + /// Return a sub-part of matrix. + inline SubMatrix Range(const MatrixIndexT row_offset, + const MatrixIndexT num_rows, + const MatrixIndexT col_offset, + const MatrixIndexT num_cols) const { + return SubMatrix(*this, row_offset, num_rows, + col_offset, num_cols); + } + inline SubMatrix RowRange(const MatrixIndexT row_offset, + const MatrixIndexT num_rows) const { + return SubMatrix(*this, row_offset, num_rows, 0, num_cols_); + } + inline SubMatrix ColRange(const MatrixIndexT col_offset, + const MatrixIndexT num_cols) const { + return SubMatrix(*this, 0, num_rows_, col_offset, num_cols); + } + + /* Various special functions. */ + /// Returns sum of all elements in matrix. + Real Sum() const; + /// Returns trace of matrix. 
+ Real Trace(bool check_square = true) const; + // If check_square = true, will crash if matrix is not square. + + /// Returns maximum element of matrix. + Real Max() const; + /// Returns minimum element of matrix. + Real Min() const; + + /// Element by element multiplication with a given matrix. + void MulElements(const MatrixBase &A); + + /// Divide each element by the corresponding element of a given matrix. + void DivElements(const MatrixBase &A); + + /// Multiply each element with a scalar value. + void Scale(Real alpha); + + /// Set, element-by-element, *this = max(*this, A) + void Max(const MatrixBase &A); + /// Set, element-by-element, *this = min(*this, A) + void Min(const MatrixBase &A); + + /// Equivalent to (*this) = (*this) * diag(scale). Scaling + /// each column by a scalar taken from that dimension of the vector. + void MulColsVec(const VectorBase &scale); + + /// Equivalent to (*this) = diag(scale) * (*this). Scaling + /// each row by a scalar taken from that dimension of the vector. + void MulRowsVec(const VectorBase &scale); + + /// Divide each row into src.NumCols() equal groups, and then scale i'th row's + /// j'th group of elements by src(i, j). Requires src.NumRows() == + /// this->NumRows() and this->NumCols() % src.NumCols() == 0. + void MulRowsGroupMat(const MatrixBase &src); + + /// Returns logdet of matrix. + Real LogDet(Real *det_sign = NULL) const; + + /// matrix inverse. + /// if inverse_needed = false, will fill matrix with garbage. + /// (only useful if logdet wanted). + void Invert(Real *log_det = NULL, Real *det_sign = NULL, + bool inverse_needed = true); + /// matrix inverse [double]. + /// if inverse_needed = false, will fill matrix with garbage + /// (only useful if logdet wanted). + /// Does inversion in double precision even if matrix was not double. 
+ void InvertDouble(Real *LogDet = NULL, Real *det_sign = NULL, + bool inverse_needed = true); + + /// Inverts all the elements of the matrix + void InvertElements(); + + /// Transpose the matrix. This one is only + /// applicable to square matrices (the one in the + /// Matrix child class works also for non-square. + void Transpose(); + + /// Copies column r from column indices[r] of src. + /// As a special case, if indexes[i] == -1, sets column i to zero. + /// all elements of "indices" must be in [-1, src.NumCols()-1], + /// and src.NumRows() must equal this.NumRows() + void CopyCols(const MatrixBase &src, + const MatrixIndexT *indices); + + /// Copies row r from row indices[r] of src (does nothing + /// As a special case, if indexes[i] == -1, sets row i to zero. + /// all elements of "indices" must be in [-1, src.NumRows()-1], + /// and src.NumCols() must equal this.NumCols() + void CopyRows(const MatrixBase &src, + const MatrixIndexT *indices); + + /// Add column indices[r] of src to column r. + /// As a special case, if indexes[i] == -1, skip column i + /// indices.size() must equal this->NumCols(), + /// all elements of "reorder" must be in [-1, src.NumCols()-1], + /// and src.NumRows() must equal this.NumRows() + void AddCols(const MatrixBase &src, + const MatrixIndexT *indices); + + /// Copies row r of this matrix from an array of floats at the location given + /// by src[r]. If any src[r] is NULL then this.Row(r) will be set to zero. + /// Note: we are using "pointer to const pointer to const object" for "src", + /// because we may create "src" by calling Data() of const CuArray + void CopyRows(const Real *const *src); + + /// Copies row r of this matrix to the array of floats at the location given + /// by dst[r]. If dst[r] is NULL, does not copy anywhere. Requires that none + /// of the memory regions pointed to by the pointers in "dst" overlap (e.g. + /// none of the pointers should be the same). 
+ void CopyToRows(Real *const *dst) const; + + /// Does for each row r, this.Row(r) += alpha * src.row(indexes[r]). + /// If indexes[r] < 0, does not add anything. all elements of "indexes" must + /// be in [-1, src.NumRows()-1], and src.NumCols() must equal this.NumCols(). + void AddRows(Real alpha, + const MatrixBase &src, + const MatrixIndexT *indexes); + + /// Does for each row r, this.Row(r) += alpha * src[r], treating src[r] as the + /// beginning of a region of memory representing a vector of floats, of the + /// same length as this.NumCols(). If src[r] is NULL, does not add anything. + void AddRows(Real alpha, const Real *const *src); + + /// For each row r of this matrix, adds it (times alpha) to the array of + /// floats at the location given by dst[r]. If dst[r] is NULL, does not do + /// anything for that row. Requires that none of the memory regions pointed + /// to by the pointers in "dst" overlap (e.g. none of the pointers should be + /// the same). + void AddToRows(Real alpha, Real *const *dst) const; + + /// For each row i of *this, adds this->Row(i) to + /// dst->Row(indexes(i)) if indexes(i) >= 0, else do nothing. + /// Requires that all the indexes[i] that are >= 0 + /// be distinct, otherwise the behavior is undefined. 
+ void AddToRows(Real alpha, + const MatrixIndexT *indexes, + MatrixBase *dst) const; + + inline void ApplyPow(Real power) { + this -> Pow(*this, power); + } + + + inline void ApplyPowAbs(Real power, bool include_sign=false) { + this -> PowAbs(*this, power, include_sign); + } + + inline void ApplyHeaviside() { + this -> Heaviside(*this); + } + + inline void ApplyFloor(Real floor_val) { + this -> Floor(*this, floor_val); + } + + inline void ApplyCeiling(Real ceiling_val) { + this -> Ceiling(*this, ceiling_val); + } + + inline void ApplyExp() { + this -> Exp(*this); + } + + inline void ApplyExpSpecial() { + this -> ExpSpecial(*this); + } + + inline void ApplyExpLimited(Real lower_limit, Real upper_limit) { + this -> ExpLimited(*this, lower_limit, upper_limit); + } + + inline void ApplyLog() { + this -> Log(*this); + } + + /// Eigenvalue Decomposition of a square NxN matrix into the form (*this) = P D + /// P^{-1}. Be careful: the relationship of D to the eigenvalues we output is + /// slightly complicated, due to the need for P to be real. In the symmetric + /// case D is diagonal and real, but in + /// the non-symmetric case there may be complex-conjugate pairs of eigenvalues. + /// In this case, for the equation (*this) = P D P^{-1} to hold, D must actually + /// be block diagonal, with 2x2 blocks corresponding to any such pairs. If a + /// pair is lambda +- i*mu, D will have a corresponding 2x2 block + /// [lambda, mu; -mu, lambda]. + /// Note that if the input matrix (*this) is non-invertible, P may not be invertible + /// so in this case instead of the equation (*this) = P D P^{-1} holding, we have + /// instead (*this) P = P D. + /// + /// The non-member function CreateEigenvalueMatrix creates D from eigs_real and eigs_imag. + void Eig(MatrixBase *P, + VectorBase *eigs_real, + VectorBase *eigs_imag) const; + + /// The Power method attempts to take the matrix to a power using a method that + /// works in general for fractional and negative powers. 
The input matrix must + /// be invertible and have reasonable condition (or we don't guarantee the + /// results. The method is based on the eigenvalue decomposition. It will + /// return false and leave the matrix unchanged, if at entry the matrix had + /// real negative eigenvalues (or if it had zero eigenvalues and the power was + /// negative). + bool Power(Real pow); + + /** Singular value decomposition + Major limitations: + For nonsquare matrices, we assume m>=n (NumRows >= NumCols), and we return + the "skinny" Svd, i.e. the matrix in the middle is diagonal, and the + one on the left is rectangular. + + In Svd, *this = U*diag(S)*Vt. + Null pointers for U and/or Vt at input mean we do not want that output. We + expect that S.Dim() == m, U is either NULL or m by n, + and v is either NULL or n by n. + The singular values are not sorted (use SortSvd for that). */ + void DestructiveSvd(VectorBase *s, MatrixBase *U, + MatrixBase *Vt); // Destroys calling matrix. + + /// Compute SVD (*this) = U diag(s) Vt. Note that the V in the call is already + /// transposed; the normal formulation is U diag(s) V^T. + /// Null pointers for U or V mean we don't want that output (this saves + /// compute). The singular values are not sorted (use SortSvd for that). + void Svd(VectorBase *s, MatrixBase *U, + MatrixBase *Vt) const; + /// Compute SVD but only retain the singular values. + void Svd(VectorBase *s) const { Svd(s, NULL, NULL); } + + + /// Returns smallest singular value. + Real MinSingularValue() const { + Vector tmp(std::min(NumRows(), NumCols())); + Svd(&tmp); + return tmp.Min(); + } + + void TestUninitialized() const; // This function is designed so that if any element + // if the matrix is uninitialized memory, valgrind will complain. + + /// Returns condition number by computing Svd. Works even if cols > rows. + /// Returns infinity if all singular values are zero. + Real Cond() const; + + /// Returns true if matrix is Symmetric. 
+ bool IsSymmetric(Real cutoff = 1.0e-05) const; // replace magic number + + /// Returns true if matrix is Diagonal. + bool IsDiagonal(Real cutoff = 1.0e-05) const; // replace magic number + + /// Returns true if the matrix is all zeros, except for ones on diagonal. (it + /// does not have to be square). More specifically, this function returns + /// false if for any i, j, (*this)(i, j) differs by more than cutoff from the + /// expression (i == j ? 1 : 0). + bool IsUnit(Real cutoff = 1.0e-05) const; // replace magic number + + /// Returns true if matrix is all zeros. + bool IsZero(Real cutoff = 1.0e-05) const; // replace magic number + + /// Frobenius norm, which is the sqrt of sum of square elements. Same as Schatten 2-norm, + /// or just "2-norm". + Real FrobeniusNorm() const; + + /// Returns true if ((*this)-other).FrobeniusNorm() + /// <= tol * (*this).FrobeniusNorm(). + bool ApproxEqual(const MatrixBase &other, float tol = 0.01) const; + + /// Tests for exact equality. It's usually preferable to use ApproxEqual. + bool Equal(const MatrixBase &other) const; + + /// largest absolute value. + Real LargestAbsElem() const; // largest absolute value. + + /// Returns log(sum(exp())) without exp overflow + /// If prune > 0.0, it uses a pruning beam, discarding + /// terms less than (max - prune). Note: in future + /// we may change this so that if prune = 0.0, it takes + /// the max, so use -1 if you don't want to prune. + Real LogSumExp(Real prune = -1.0) const; + + /// Apply soft-max to the collection of all elements of the + /// matrix and return normalizer (log sum of exponentials). + Real ApplySoftMax(); + + /// Set each element to the sigmoid of the corresponding element of "src". + void Sigmoid(const MatrixBase &src); + + /// Sets each element to the Heaviside step function (x > 0 ? 1 : 0) of the + /// corresponding element in "src". Note: in general you can make different + /// choices for x = 0, but for now please leave it as it (i.e. 
returning zero) + /// because it affects the RectifiedLinearComponent in the neural net code. + void Heaviside(const MatrixBase &src); + + void Exp(const MatrixBase &src); + + void Pow(const MatrixBase &src, Real power); + + void Log(const MatrixBase &src); + + /// Apply power to the absolute value of each element. + /// If include_sign is true, the result will be multiplied with + /// the sign of the input value. + /// If the power is negative and the input to the power is zero, + /// The output will be set zero. If include_sign is true, it will + /// multiply the result by the sign of the input. + void PowAbs(const MatrixBase &src, Real power, bool include_sign=false); + + void Floor(const MatrixBase &src, Real floor_val); + + void Ceiling(const MatrixBase &src, Real ceiling_val); + + /// For each element x of the matrix, set it to + /// (x < 0 ? exp(x) : x + 1). This function is used + /// in our RNNLM training. + void ExpSpecial(const MatrixBase &src); + + /// This is equivalent to running: + /// Floor(src, lower_limit); + /// Ceiling(src, upper_limit); + /// Exp(src) + void ExpLimited(const MatrixBase &src, Real lower_limit, Real upper_limit); + + /// Set each element to y = log(1 + exp(x)) + void SoftHinge(const MatrixBase &src); + + /// Apply the function y(i) = (sum_{j = i*G}^{(i+1)*G-1} x_j^(power))^(1 / p). + /// Requires src.NumRows() == this->NumRows() and src.NumCols() % this->NumCols() == 0. + void GroupPnorm(const MatrixBase &src, Real power); + + /// Calculate derivatives for the GroupPnorm function above... + /// if "input" is the input to the GroupPnorm function above (i.e. the "src" variable), + /// and "output" is the result of the computation (i.e. the "this" of that function + /// call), and *this has the same dimension as "input", then it sets each element + /// of *this to the derivative d(output-elem)/d(input-elem) for each element of "input", where + /// "output-elem" is whichever element of output depends on that input element. 
+ void GroupPnormDeriv(const MatrixBase &input, const MatrixBase &output, + Real power); + + /// Apply the function y(i) = (max_{j = i*G}^{(i+1)*G-1} x_j + /// Requires src.NumRows() == this->NumRows() and src.NumCols() % this->NumCols() == 0. + void GroupMax(const MatrixBase &src); + + /// Calculate derivatives for the GroupMax function above, where + /// "input" is the input to the GroupMax function above (i.e. the "src" variable), + /// and "output" is the result of the computation (i.e. the "this" of that function + /// call), and *this must have the same dimension as "input". Each element + /// of *this will be set to 1 if the corresponding input equals the output of + /// the group, and 0 otherwise. The equals the function derivative where it is + /// defined (it's not defined where multiple inputs in the group are equal to the output). + void GroupMaxDeriv(const MatrixBase &input, const MatrixBase &output); + + /// Set each element to the tanh of the corresponding element of "src". + void Tanh(const MatrixBase &src); + + // Function used in backpropagating derivatives of the sigmoid function: + // element-by-element, set *this = diff * value * (1.0 - value). + void DiffSigmoid(const MatrixBase &value, + const MatrixBase &diff); + + // Function used in backpropagating derivatives of the tanh function: + // element-by-element, set *this = diff * (1.0 - value^2). + void DiffTanh(const MatrixBase &value, + const MatrixBase &diff); + + /** Uses Svd to compute the eigenvalue decomposition of a symmetric positive + * semi-definite matrix: (*this) = rP * diag(rS) * rP^T, with rP an + * orthogonal matrix so rP^{-1} = rP^T. Throws exception if input was not + * positive semi-definite (check_thresh controls how stringent the check is; + * set it to 2 to ensure it won't ever complain, but it will zero out negative + * dimensions in your matrix. 
+ * + * Caution: if you want the eigenvalues, it may make more sense to convert to + * SpMatrix and use Eig() function there, which uses eigenvalue decomposition + * directly rather than SVD. + */ + void SymPosSemiDefEig(VectorBase *s, MatrixBase *P, + Real check_thresh = 0.001); + + friend Real kaldi::TraceMatMat(const MatrixBase &A, + const MatrixBase &B, MatrixTransposeType trans); // tr (A B) + + // so it can get around const restrictions on the pointer to data_. + friend class SubMatrix; + + /// Add a scalar to each element + void Add(const Real alpha); + + /// Add a scalar to each diagonal element. + void AddToDiag(const Real alpha); + + /// *this += alpha * a * b^T + template + void AddVecVec(const Real alpha, const VectorBase &a, + const VectorBase &b); + + /// [each row of *this] += alpha * v + template + void AddVecToRows(const Real alpha, const VectorBase &v); + + /// [each col of *this] += alpha * v + template + void AddVecToCols(const Real alpha, const VectorBase &v); + + /// *this += alpha * M [or M^T] + void AddMat(const Real alpha, const MatrixBase &M, + MatrixTransposeType transA = kNoTrans); + + /// *this += alpha * A [or A^T]. + void AddSmat(Real alpha, const SparseMatrix &A, + MatrixTransposeType trans = kNoTrans); + + /// (*this) = alpha * op(A) * B + beta * (*this), where A is sparse. + /// Multiplication of sparse with dense matrix. See also AddMatSmat. + void AddSmatMat(Real alpha, const SparseMatrix &A, + MatrixTransposeType transA, const MatrixBase &B, + Real beta); + + /// (*this) = alpha * A * op(B) + beta * (*this), where B is sparse + /// and op(B) is either B or trans(B) depending on the 'transB' argument. + /// This is multiplication of a dense by a sparse matrix. See also + /// AddSmatMat. + void AddMatSmat(Real alpha, const MatrixBase &A, + const SparseMatrix &B, MatrixTransposeType transB, + Real beta); + + /// *this = beta * *this + alpha * M M^T, for symmetric matrices. It only + /// updates the lower triangle of *this. 
It will leave the matrix asymmetric; + /// if you need it symmetric as a regular matrix, do CopyLowerToUpper(). + void SymAddMat2(const Real alpha, const MatrixBase &M, + MatrixTransposeType transA, Real beta); + + /// *this = beta * *this + alpha * diag(v) * M [or M^T]. + /// The same as adding M but scaling each row M_i by v(i). + void AddDiagVecMat(const Real alpha, const VectorBase &v, + const MatrixBase &M, MatrixTransposeType transM, + Real beta = 1.0); + + /// *this = beta * *this + alpha * M [or M^T] * diag(v) + /// The same as adding M but scaling each column M_j by v(j). + void AddMatDiagVec(const Real alpha, + const MatrixBase &M, MatrixTransposeType transM, + VectorBase &v, + Real beta = 1.0); + + /// *this = beta * *this + alpha * A .* B (.* element by element multiplication) + void AddMatMatElements(const Real alpha, + const MatrixBase& A, + const MatrixBase& B, + const Real beta); + + /// *this += alpha * S + template + void AddSp(const Real alpha, const SpMatrix &S); + + void AddMatMat(const Real alpha, + const MatrixBase& A, MatrixTransposeType transA, + const MatrixBase& B, MatrixTransposeType transB, + const Real beta); + + /// *this = a * b / c (by element; when c = 0, *this = a) + void SetMatMatDivMat(const MatrixBase& A, + const MatrixBase& B, + const MatrixBase& C); + + /// A version of AddMatMat specialized for when the second argument + /// contains a lot of zeroes. + void AddMatSmat(const Real alpha, + const MatrixBase& A, MatrixTransposeType transA, + const MatrixBase& B, MatrixTransposeType transB, + const Real beta); + + /// A version of AddMatMat specialized for when the first argument + /// contains a lot of zeroes. + void AddSmatMat(const Real alpha, + const MatrixBase& A, MatrixTransposeType transA, + const MatrixBase& B, MatrixTransposeType transB, + const Real beta); + + /// this <-- beta*this + alpha*A*B*C. 
+ void AddMatMatMat(const Real alpha, + const MatrixBase& A, MatrixTransposeType transA, + const MatrixBase& B, MatrixTransposeType transB, + const MatrixBase& C, MatrixTransposeType transC, + const Real beta); + + /// this <-- beta*this + alpha*SpA*B. + // This and the routines below are really + // stubs that need to be made more efficient. + void AddSpMat(const Real alpha, + const SpMatrix& A, + const MatrixBase& B, MatrixTransposeType transB, + const Real beta) { + Matrix M(A); + return AddMatMat(alpha, M, kNoTrans, B, transB, beta); + } + /// this <-- beta*this + alpha*A*B. + void AddTpMat(const Real alpha, + const TpMatrix& A, MatrixTransposeType transA, + const MatrixBase& B, MatrixTransposeType transB, + const Real beta) { + Matrix M(A); + return AddMatMat(alpha, M, transA, B, transB, beta); + } + /// this <-- beta*this + alpha*A*B. + void AddMatSp(const Real alpha, + const MatrixBase& A, MatrixTransposeType transA, + const SpMatrix& B, + const Real beta) { + Matrix M(B); + return AddMatMat(alpha, A, transA, M, kNoTrans, beta); + } + /// this <-- beta*this + alpha*A*B*C. + void AddSpMatSp(const Real alpha, + const SpMatrix &A, + const MatrixBase& B, MatrixTransposeType transB, + const SpMatrix& C, + const Real beta) { + Matrix M(A), N(C); + return AddMatMatMat(alpha, M, kNoTrans, B, transB, N, kNoTrans, beta); + } + /// this <-- beta*this + alpha*A*B. + void AddMatTp(const Real alpha, + const MatrixBase& A, MatrixTransposeType transA, + const TpMatrix& B, MatrixTransposeType transB, + const Real beta) { + Matrix M(B); + return AddMatMat(alpha, A, transA, M, transB, beta); + } + + /// this <-- beta*this + alpha*A*B. + void AddTpTp(const Real alpha, + const TpMatrix& A, MatrixTransposeType transA, + const TpMatrix& B, MatrixTransposeType transB, + const Real beta) { + Matrix M(A), N(B); + return AddMatMat(alpha, M, transA, N, transB, beta); + } + + /// this <-- beta*this + alpha*A*B. + // This one is more efficient, not like the others above. 
+ void AddSpSp(const Real alpha, + const SpMatrix& A, const SpMatrix& B, + const Real beta); + + /// Copy lower triangle to upper triangle (symmetrize) + void CopyLowerToUpper(); + + /// Copy upper triangle to lower triangle (symmetrize) + void CopyUpperToLower(); + + /// This function orthogonalizes the rows of a matrix using the Gram-Schmidt + /// process. It is only applicable if NumRows() <= NumCols(). It will use + /// random number generation to fill in rows with something nonzero, in cases + /// where the original matrix was of deficient row rank. + void OrthogonalizeRows(); + + /// stream read. + /// Use instead of stream<<*this, if you want to add to existing contents. + // Will throw exception on failure. + void Read(std::istream & in, bool binary, bool add = false); + /// write to stream. + void Write(std::ostream & out, bool binary) const; + + // Below is internal methods for Svd, user does not have to know about this. +#if !defined(HAVE_ATLAS) && !defined(USE_KALDI_SVD) + // protected: + // Should be protected but used directly in testing routine. + // destroys *this! + void LapackGesvd(VectorBase *s, MatrixBase *U, + MatrixBase *Vt); +#else + protected: + // destroys *this! + bool JamaSvd(VectorBase *s, MatrixBase *U, + MatrixBase *V); + +#endif + protected: + + /// Initializer, callable only from child. + explicit MatrixBase(Real *data, MatrixIndexT cols, MatrixIndexT rows, MatrixIndexT stride) : + data_(data), num_cols_(cols), num_rows_(rows), stride_(stride) { + KALDI_ASSERT_IS_FLOATING_TYPE(Real); + } + + /// Initializer, callable only from child. + /// Empty initializer, for un-initialized matrix. + explicit MatrixBase(): data_(NULL) { + KALDI_ASSERT_IS_FLOATING_TYPE(Real); + } + + // Make sure pointers to MatrixBase cannot be deleted. + ~MatrixBase() { } + + /// A workaround that allows SubMatrix to get a pointer to non-const data + /// for const Matrix. 
Unfortunately C++ does not allow us to declare a + /// "public const" inheritance or anything like that, so it would require + /// a lot of work to make the SubMatrix class totally const-correct-- + /// we would have to override many of the Matrix functions. + inline Real* Data_workaround() const { + return data_; + } + + /// data memory area + Real* data_; + + /// these attributes store the real matrix size as it is stored in memory + /// including memalignment + MatrixIndexT num_cols_; /// < Number of columns + MatrixIndexT num_rows_; /// < Number of rows + /** True number of columns for the internal matrix. This number may differ + * from num_cols_ as memory alignment might be used. */ + MatrixIndexT stride_; + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(MatrixBase); +}; + +/// A class for storing matrices. +template +class Matrix : public MatrixBase { + public: + + /// Empty constructor. + Matrix(); + + /// Basic constructor. + Matrix(const MatrixIndexT r, const MatrixIndexT c, + MatrixResizeType resize_type = kSetZero, + MatrixStrideType stride_type = kDefaultStride): + MatrixBase() { Resize(r, c, resize_type, stride_type); } + + /// Copy constructor from CUDA matrix + /// This is defined in ../cudamatrix/cu-matrix.h + template + explicit Matrix(const CuMatrixBase &cu, + MatrixTransposeType trans = kNoTrans); + + + /// Swaps the contents of *this and *other. Shallow swap. + void Swap(Matrix *other); + + /// Defined in ../cudamatrix/cu-matrix.cc + void Swap(CuMatrix *mat); + + /// Constructor from any MatrixBase. Can also copy with transpose. + /// Allocates new memory. + explicit Matrix(const MatrixBase & M, + MatrixTransposeType trans = kNoTrans); + + /// Same as above, but need to avoid default copy constructor. + Matrix(const Matrix & M); // (cannot make explicit) + + /// Copy constructor: as above, but from another type. + template + explicit Matrix(const MatrixBase & M, + MatrixTransposeType trans = kNoTrans); + + /// Copy constructor taking SpMatrix... 
+ /// It is symmetric, so no option for transpose, and NumRows == Cols + template + explicit Matrix(const SpMatrix & M) : MatrixBase() { + Resize(M.NumRows(), M.NumRows(), kUndefined); + this->CopyFromSp(M); + } + + /// Constructor from CompressedMatrix + explicit Matrix(const CompressedMatrix &C); + + /// Copy constructor taking TpMatrix... + template + explicit Matrix(const TpMatrix & M, + MatrixTransposeType trans = kNoTrans) : MatrixBase() { + if (trans == kNoTrans) { + Resize(M.NumRows(), M.NumCols(), kUndefined); + this->CopyFromTp(M); + } else { + Resize(M.NumCols(), M.NumRows(), kUndefined); + this->CopyFromTp(M, kTrans); + } + } + + /// read from stream. + // Unlike one in base, allows resizing. + void Read(std::istream & in, bool binary, bool add = false); + + /// Remove a specified row. + void RemoveRow(MatrixIndexT i); + + /// Transpose the matrix. Works for non-square + /// matrices as well as square ones. + void Transpose(); + + /// Distructor to free matrices. + ~Matrix() { Destroy(); } + + /// Sets matrix to a specified size (zero is OK as long as both r and c are + /// zero). The value of the new data depends on resize_type: + /// -if kSetZero, the new data will be zero + /// -if kUndefined, the new data will be undefined + /// -if kCopyData, the new data will be the same as the old data in any + /// shared positions, and zero elsewhere. + /// + /// You can set stride_type to kStrideEqualNumCols to force the stride + /// to equal the number of columns; by default it is set so that the stride + /// in bytes is a multiple of 16. + /// + /// This function takes time proportional to the number of data elements. + void Resize(const MatrixIndexT r, + const MatrixIndexT c, + MatrixResizeType resize_type = kSetZero, + MatrixStrideType stride_type = kDefaultStride); + + /// Assignment operator that takes MatrixBase. 
+ Matrix &operator = (const MatrixBase &other) { + if (MatrixBase::NumRows() != other.NumRows() || + MatrixBase::NumCols() != other.NumCols()) + Resize(other.NumRows(), other.NumCols(), kUndefined); + MatrixBase::CopyFromMat(other); + return *this; + } + + /// Assignment operator. Needed for inclusion in std::vector. + Matrix &operator = (const Matrix &other) { + if (MatrixBase::NumRows() != other.NumRows() || + MatrixBase::NumCols() != other.NumCols()) + Resize(other.NumRows(), other.NumCols(), kUndefined); + MatrixBase::CopyFromMat(other); + return *this; + } + + + private: + /// Deallocates memory and sets to empty matrix (dimension 0, 0). + void Destroy(); + + /// Init assumes the current class contents are invalid (i.e. junk or have + /// already been freed), and it sets the matrix to newly allocated memory with + /// the specified number of rows and columns. r == c == 0 is acceptable. The data + /// memory contents will be undefined. + void Init(const MatrixIndexT r, + const MatrixIndexT c, + const MatrixStrideType stride_type); + +}; +/// @} end "addtogroup matrix_group" + +/// \addtogroup matrix_funcs_io +/// @{ + +/// A structure containing the HTK header. +/// [TODO: change the style of the variables to Kaldi-compliant] +struct HtkHeader { + /// Number of samples. + int32 mNSamples; + /// Sample period. + int32 mSamplePeriod; + /// Sample size + int16 mSampleSize; + /// Sample kind. + uint16 mSampleKind; +}; + +// Read HTK formatted features from file into matrix. +template +bool ReadHtk(std::istream &is, Matrix *M, HtkHeader *header_ptr); + +// Write (HTK format) features to file from matrix. +template +bool WriteHtk(std::ostream &os, const MatrixBase &M, HtkHeader htk_hdr); + +// Write (CMUSphinx format) features to file from matrix. +template +bool WriteSphinx(std::ostream &os, const MatrixBase &M); + +/// @} end of "addtogroup matrix_funcs_io" + +/** + Sub-matrix representation. + Can work with sub-parts of a matrix using this class. 
+ Note that SubMatrix is not very const-correct-- it allows you to + change the contents of a const Matrix. Be careful! +*/ + +template +class SubMatrix : public MatrixBase { + public: + // Initialize a SubMatrix from part of a matrix; this is + // a bit like A(b:c, d:e) in Matlab. + // This initializer is against the proper semantics of "const", since + // SubMatrix can change its contents. It would be hard to implement + // a "const-safe" version of this class. + SubMatrix(const MatrixBase& T, + const MatrixIndexT ro, // row offset, 0 < ro < NumRows() + const MatrixIndexT r, // number of rows, r > 0 + const MatrixIndexT co, // column offset, 0 < co < NumCols() + const MatrixIndexT c); // number of columns, c > 0 + + // This initializer is mostly intended for use in CuMatrix and related + // classes. Be careful! + SubMatrix(Real *data, + MatrixIndexT num_rows, + MatrixIndexT num_cols, + MatrixIndexT stride); + + ~SubMatrix() {} + + /// This type of constructor is needed for Range() to work [in Matrix base + /// class]. Cannot make it explicit. + SubMatrix (const SubMatrix &other): + MatrixBase (other.data_, other.num_cols_, other.num_rows_, + other.stride_) {} + + private: + /// Disallow assignment. + SubMatrix &operator = (const SubMatrix &other); +}; +/// @} End of "addtogroup matrix_funcs_io". + +/// \addtogroup matrix_funcs_scalar +/// @{ + +// Some declarations. These are traces of products. + + +template +bool ApproxEqual(const MatrixBase &A, + const MatrixBase &B, Real tol = 0.01) { + return A.ApproxEqual(B, tol); +} + +template +inline void AssertEqual(const MatrixBase &A, const MatrixBase &B, + float tol = 0.01) { + KALDI_ASSERT(A.ApproxEqual(B, tol)); +} + +/// Returns trace of matrix. 
+template +double TraceMat(const MatrixBase &A) { return A.Trace(); } + + +/// Returns tr(A B C) +template +Real TraceMatMatMat(const MatrixBase &A, MatrixTransposeType transA, + const MatrixBase &B, MatrixTransposeType transB, + const MatrixBase &C, MatrixTransposeType transC); + +/// Returns tr(A B C D) +template +Real TraceMatMatMatMat(const MatrixBase &A, MatrixTransposeType transA, + const MatrixBase &B, MatrixTransposeType transB, + const MatrixBase &C, MatrixTransposeType transC, + const MatrixBase &D, MatrixTransposeType transD); + +/// @} end "addtogroup matrix_funcs_scalar" + + +/// \addtogroup matrix_funcs_misc +/// @{ + + +/// Function to ensure that SVD is sorted. This function is made as generic as +/// possible, to be applicable to other types of problems. s->Dim() should be +/// the same as U->NumCols(), and we sort s from greatest to least absolute +/// value (if sort_on_absolute_value == true) or greatest to least value +/// otherwise, moving the columns of U, if it exists, and the rows of Vt, if it +/// exists, around in the same way. Note: the "absolute value" part won't matter +/// if this is an actual SVD, since singular values are non-negative. +template void SortSvd(VectorBase *s, MatrixBase *U, + MatrixBase* Vt = NULL, + bool sort_on_absolute_value = true); + +/// Creates the eigenvalue matrix D that is part of the decomposition used Matrix::Eig. +/// D will be block-diagonal with blocks of size 1 (for real eigenvalues) or 2x2 +/// for complex pairs. If a complex pair is lambda +- i*mu, D will have a corresponding +/// 2x2 block [lambda, mu; -mu, lambda]. +/// This function will throw if any complex eigenvalues are not in complex conjugate +/// pairs (or the members of such pairs are not consecutively numbered). 
+template +void CreateEigenvalueMatrix(const VectorBase &real, const VectorBase &imag, + MatrixBase *D); + +/// The following function is used in Matrix::Power, and separately tested, so we +/// declare it here mainly for the testing code to see. It takes a complex value to +/// a power using a method that will work for noninteger powers (but will fail if the +/// complex value is real and negative). +template +bool AttemptComplexPower(Real *x_re, Real *x_im, Real power); + + + +/// @} end of addtogroup matrix_funcs_misc + +/// \addtogroup matrix_funcs_io +/// @{ +template +std::ostream & operator << (std::ostream & Out, const MatrixBase & M); + +template +std::istream & operator >> (std::istream & In, MatrixBase & M); + +// The Matrix read allows resizing, so we override the MatrixBase one. +template +std::istream & operator >> (std::istream & In, Matrix & M); + + +template +bool SameDim(const MatrixBase &M, const MatrixBase &N) { + return (M.NumRows() == N.NumRows() && M.NumCols() == N.NumCols()); +} + +/// @} end of \addtogroup matrix_funcs_io + + +} // namespace kaldi + + + +// we need to include the implementation and some +// template specializations. +#include "matrix/kaldi-matrix-inl.h" + + +#endif // KALDI_MATRIX_KALDI_MATRIX_H_ diff --git a/speechx/speechx/kaldi/matrix/kaldi-vector-inl.h b/speechx/speechx/kaldi/matrix/kaldi-vector-inl.h new file mode 100644 index 00000000..c3a4f52f --- /dev/null +++ b/speechx/speechx/kaldi/matrix/kaldi-vector-inl.h @@ -0,0 +1,58 @@ +// matrix/kaldi-vector-inl.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; +// Haihua Xu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// This is an internal header file, included by other library headers. +// You should not attempt to use it directly. + +#ifndef KALDI_MATRIX_KALDI_VECTOR_INL_H_ +#define KALDI_MATRIX_KALDI_VECTOR_INL_H_ 1 + +namespace kaldi { + +template +std::ostream & operator << (std::ostream &os, const VectorBase &rv) { + rv.Write(os, false); + return os; +} + +template +std::istream &operator >> (std::istream &is, VectorBase &rv) { + rv.Read(is, false); + return is; +} + +template +std::istream &operator >> (std::istream &is, Vector &rv) { + rv.Read(is, false); + return is; +} + +template<> +template<> +void VectorBase::AddVec(const float alpha, const VectorBase &rv); + +template<> +template<> +void VectorBase::AddVec(const double alpha, + const VectorBase &rv); + +} // namespace kaldi + +#endif // KALDI_MATRIX_KALDI_VECTOR_INL_H_ diff --git a/speechx/speechx/kaldi/matrix/kaldi-vector.cc b/speechx/speechx/kaldi/matrix/kaldi-vector.cc new file mode 100644 index 00000000..ccc7e89b --- /dev/null +++ b/speechx/speechx/kaldi/matrix/kaldi-vector.cc @@ -0,0 +1,1355 @@ +// matrix/kaldi-vector.cc + +// Copyright 2009-2011 Microsoft Corporation; Lukas Burget; +// Saarland University; Go Vivace Inc.; Ariya Rastrow; +// Petr Schwarz; Yanmin Qian; Jan Silovsky; +// Haihua Xu; Wei Shi +// 2015 Guoguo Chen +// 2017 Daniel Galvez +// 2019 Yiwen Shao + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use 
this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include "matrix/cblas-wrappers.h" +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" +#include "matrix/sp-matrix.h" +#include "matrix/sparse-matrix.h" + +namespace kaldi { + +template +Real VecVec(const VectorBase &a, + const VectorBase &b) { + MatrixIndexT adim = a.Dim(); + KALDI_ASSERT(adim == b.Dim()); + return cblas_Xdot(adim, a.Data(), 1, b.Data(), 1); +} + +template +float VecVec<>(const VectorBase &a, + const VectorBase &b); +template +double VecVec<>(const VectorBase &a, + const VectorBase &b); + +template +Real VecVec(const VectorBase &ra, + const VectorBase &rb) { + MatrixIndexT adim = ra.Dim(); + KALDI_ASSERT(adim == rb.Dim()); + const Real *a_data = ra.Data(); + const OtherReal *b_data = rb.Data(); + Real sum = 0.0; + for (MatrixIndexT i = 0; i < adim; i++) + sum += a_data[i]*b_data[i]; + return sum; +} + +// instantiate the template above. 
+template +float VecVec<>(const VectorBase &ra, + const VectorBase &rb); +template +double VecVec<>(const VectorBase &ra, + const VectorBase &rb); + + +template<> +template<> +void VectorBase::AddVec(const float alpha, + const VectorBase &v) { + KALDI_ASSERT(dim_ == v.dim_); + KALDI_ASSERT(&v != this); + cblas_Xaxpy(dim_, alpha, v.Data(), 1, data_, 1); +} + +template<> +template<> +void VectorBase::AddVec(const double alpha, + const VectorBase &v) { + KALDI_ASSERT(dim_ == v.dim_); + KALDI_ASSERT(&v != this); + cblas_Xaxpy(dim_, alpha, v.Data(), 1, data_, 1); +} + +template +void VectorBase::AddMatVec(const Real alpha, + const MatrixBase &M, + MatrixTransposeType trans, + const VectorBase &v, + const Real beta) { + KALDI_ASSERT((trans == kNoTrans && M.NumCols() == v.dim_ && M.NumRows() == dim_) + || (trans == kTrans && M.NumRows() == v.dim_ && M.NumCols() == dim_)); + KALDI_ASSERT(&v != this); + cblas_Xgemv(trans, M.NumRows(), M.NumCols(), alpha, M.Data(), M.Stride(), + v.Data(), 1, beta, data_, 1); +} + +template +void VectorBase::AddMatSvec(const Real alpha, + const MatrixBase &M, + MatrixTransposeType trans, + const VectorBase &v, + const Real beta) { + KALDI_ASSERT((trans == kNoTrans && M.NumCols() == v.dim_ && M.NumRows() == dim_) + || (trans == kTrans && M.NumRows() == v.dim_ && M.NumCols() == dim_)); + KALDI_ASSERT(&v != this); + Xgemv_sparsevec(trans, M.NumRows(), M.NumCols(), alpha, M.Data(), M.Stride(), + v.Data(), 1, beta, data_, 1); + return; + /* + MatrixIndexT this_dim = this->dim_, v_dim = v.dim_, + M_stride = M.Stride(); + Real *this_data = this->data_; + const Real *M_data = M.Data(), *v_data = v.data_; + if (beta != 1.0) this->Scale(beta); + if (trans == kNoTrans) { + for (MatrixIndexT i = 0; i < v_dim; i++) { + Real v_i = v_data[i]; + if (v_i == 0.0) continue; + // Add to *this, the i'th column of the Matrix, times v_i. 
+ cblas_Xaxpy(this_dim, v_i * alpha, M_data + i, M_stride, this_data, 1); + } + } else { // The transposed case is slightly more efficient, I guess. + for (MatrixIndexT i = 0; i < v_dim; i++) { + Real v_i = v.data_[i]; + if (v_i == 0.0) continue; + // Add to *this, the i'th row of the Matrix, times v_i. + cblas_Xaxpy(this_dim, v_i * alpha, + M_data + (i * M_stride), 1, this_data, 1); + } + }*/ +} + +template +void VectorBase::AddSpVec(const Real alpha, + const SpMatrix &M, + const VectorBase &v, + const Real beta) { + KALDI_ASSERT(M.NumRows() == v.dim_ && dim_ == v.dim_); + KALDI_ASSERT(&v != this); + cblas_Xspmv(alpha, M.NumRows(), M.Data(), v.Data(), 1, beta, data_, 1); +} + + +template +void VectorBase::MulTp(const TpMatrix &M, + const MatrixTransposeType trans) { + KALDI_ASSERT(M.NumRows() == dim_); + cblas_Xtpmv(trans,M.Data(),M.NumRows(),data_,1); +} + +template +void VectorBase::Solve(const TpMatrix &M, + const MatrixTransposeType trans) { + KALDI_ASSERT(M.NumRows() == dim_); + cblas_Xtpsv(trans, M.Data(), M.NumRows(), data_, 1); +} + + +template +inline void Vector::Init(const MatrixIndexT dim) { + KALDI_ASSERT(dim >= 0); + if (dim == 0) { + this->dim_ = 0; + this->data_ = NULL; + return; + } + MatrixIndexT size; + void *data; + void *free_data; + + size = dim * sizeof(Real); + + if ((data = KALDI_MEMALIGN(16, size, &free_data)) != NULL) { + this->data_ = static_cast (data); + this->dim_ = dim; + } else { + throw std::bad_alloc(); + } +} + + +template +void Vector::Resize(const MatrixIndexT dim, MatrixResizeType resize_type) { + + // the next block uses recursion to handle what we have to do if + // resize_type == kCopyData. + if (resize_type == kCopyData) { + if (this->data_ == NULL || dim == 0) resize_type = kSetZero; // nothing to copy. + else if (this->dim_ == dim) { return; } // nothing to do. + else { + // set tmp to a vector of the desired size. 
+ Vector tmp(dim, kUndefined); + if (dim > this->dim_) { + memcpy(tmp.data_, this->data_, sizeof(Real)*this->dim_); + memset(tmp.data_+this->dim_, 0, sizeof(Real)*(dim-this->dim_)); + } else { + memcpy(tmp.data_, this->data_, sizeof(Real)*dim); + } + tmp.Swap(this); + // and now let tmp go out of scope, deleting what was in *this. + return; + } + } + // At this point, resize_type == kSetZero or kUndefined. + + if (this->data_ != NULL) { + if (this->dim_ == dim) { + if (resize_type == kSetZero) this->SetZero(); + return; + } else { + Destroy(); + } + } + Init(dim); + if (resize_type == kSetZero) this->SetZero(); +} + + +/// Copy data from another vector +template +void VectorBase::CopyFromVec(const VectorBase &v) { + KALDI_ASSERT(Dim() == v.Dim()); + if (data_ != v.data_) { + std::memcpy(this->data_, v.data_, dim_ * sizeof(Real)); + } +} + +template +template +void VectorBase::CopyFromPacked(const PackedMatrix& M) { + SubVector v(M); + this->CopyFromVec(v); +} +// instantiate the template. +template void VectorBase::CopyFromPacked(const PackedMatrix &other); +template void VectorBase::CopyFromPacked(const PackedMatrix &other); +template void VectorBase::CopyFromPacked(const PackedMatrix &other); +template void VectorBase::CopyFromPacked(const PackedMatrix &other); + +/// Load data into the vector +template +void VectorBase::CopyFromPtr(const Real *data, MatrixIndexT sz) { + KALDI_ASSERT(dim_ == sz); + std::memcpy(this->data_, data, Dim() * sizeof(Real)); +} + +template +template +void VectorBase::CopyFromVec(const VectorBase &other) { + KALDI_ASSERT(dim_ == other.Dim()); + Real * __restrict__ ptr = data_; + const OtherReal * __restrict__ other_ptr = other.Data(); + for (MatrixIndexT i = 0; i < dim_; i++) + ptr[i] = other_ptr[i]; +} + +template void VectorBase::CopyFromVec(const VectorBase &other); +template void VectorBase::CopyFromVec(const VectorBase &other); + +// Remove element from the vector. 
The vector is not reallocated +template +void Vector::RemoveElement(MatrixIndexT i) { + KALDI_ASSERT(i < this->dim_ && "Access out of vector"); + for (MatrixIndexT j = i + 1; j < this->dim_; j++) + this->data_[j-1] = this->data_[j]; + this->dim_--; +} + + +/// Deallocates memory and sets object to empty vector. +template +void Vector::Destroy() { + /// we need to free the data block if it was defined + if (this->data_ != NULL) + KALDI_MEMALIGN_FREE(this->data_); + this->data_ = NULL; + this->dim_ = 0; +} + +template +void VectorBase::SetZero() { + std::memset(data_, 0, dim_ * sizeof(Real)); +} + +template +bool VectorBase::IsZero(Real cutoff) const { + Real abs_max = 0.0; + for (MatrixIndexT i = 0; i < Dim(); i++) + abs_max = std::max(std::abs(data_[i]), abs_max); + return (abs_max <= cutoff); +} + +template +void VectorBase::SetRandn() { + kaldi::RandomState rstate; + MatrixIndexT last = (Dim() % 2 == 1) ? Dim() - 1 : Dim(); + for (MatrixIndexT i = 0; i < last; i += 2) { + kaldi::RandGauss2(data_ + i, data_ + i + 1, &rstate); + } + if (Dim() != last) data_[last] = static_cast(kaldi::RandGauss(&rstate)); +} + +template +void VectorBase::SetRandUniform() { + kaldi::RandomState rstate; + for (MatrixIndexT i = 0; i < Dim(); i++) { + *(data_+i) = RandUniform(&rstate); + } +} + +template +MatrixIndexT VectorBase::RandCategorical() const { + kaldi::RandomState rstate; + Real sum = this->Sum(); + KALDI_ASSERT(this->Min() >= 0.0 && sum > 0.0); + Real r = RandUniform(&rstate) * sum; + Real *data = this->data_; + MatrixIndexT dim = this->dim_; + Real running_sum = 0.0; + for (MatrixIndexT i = 0; i < dim; i++) { + running_sum += data[i]; + if (r < running_sum) return i; + } + return dim_ - 1; // Should only happen if RandUniform() + // returns exactly 1, or due to roundoff. +} + +template +void VectorBase::Set(Real f) { + // Why not use memset here? + // The basic unit of memset is a byte. + // If f != 0 and sizeof(Real) > 1, then we cannot use memset. 
+ if (f == 0) { + this->SetZero(); // calls std::memset + } else { + for (MatrixIndexT i = 0; i < dim_; i++) { data_[i] = f; } + } +} + +template +void VectorBase::CopyRowsFromMat(const MatrixBase &mat) { + KALDI_ASSERT(dim_ == mat.NumCols() * mat.NumRows()); + + Real *inc_data = data_; + const MatrixIndexT cols = mat.NumCols(), rows = mat.NumRows(); + + if (mat.Stride() == mat.NumCols()) { + memcpy(inc_data, mat.Data(), cols*rows*sizeof(Real)); + } else { + for (MatrixIndexT i = 0; i < rows; i++) { + // copy the data to the propper position + memcpy(inc_data, mat.RowData(i), cols * sizeof(Real)); + // set new copy position + inc_data += cols; + } + } +} + +template +template +void VectorBase::CopyRowsFromMat(const MatrixBase &mat) { + KALDI_ASSERT(dim_ == mat.NumCols() * mat.NumRows()); + Real *vec_data = data_; + const MatrixIndexT cols = mat.NumCols(), + rows = mat.NumRows(); + + for (MatrixIndexT i = 0; i < rows; i++) { + const OtherReal *mat_row = mat.RowData(i); + for (MatrixIndexT j = 0; j < cols; j++) { + vec_data[j] = static_cast(mat_row[j]); + } + vec_data += cols; + } +} + +template +void VectorBase::CopyRowsFromMat(const MatrixBase &mat); +template +void VectorBase::CopyRowsFromMat(const MatrixBase &mat); + + +template +void VectorBase::CopyColsFromMat(const MatrixBase &mat) { + KALDI_ASSERT(dim_ == mat.NumCols() * mat.NumRows()); + + Real* inc_data = data_; + const MatrixIndexT cols = mat.NumCols(), rows = mat.NumRows(), stride = mat.Stride(); + const Real *mat_inc_data = mat.Data(); + + for (MatrixIndexT i = 0; i < cols; i++) { + for (MatrixIndexT j = 0; j < rows; j++) { + inc_data[j] = mat_inc_data[j*stride]; + } + mat_inc_data++; + inc_data += rows; + } +} + +template +void VectorBase::CopyRowFromMat(const MatrixBase &mat, MatrixIndexT row) { + KALDI_ASSERT(row < mat.NumRows()); + KALDI_ASSERT(dim_ == mat.NumCols()); + const Real *mat_row = mat.RowData(row); + memcpy(data_, mat_row, sizeof(Real)*dim_); +} + +template +template +void 
VectorBase::CopyRowFromMat(const MatrixBase &mat, MatrixIndexT row) { + KALDI_ASSERT(row < mat.NumRows()); + KALDI_ASSERT(dim_ == mat.NumCols()); + const OtherReal *mat_row = mat.RowData(row); + for (MatrixIndexT i = 0; i < dim_; i++) + data_[i] = static_cast(mat_row[i]); +} + +template +void VectorBase::CopyRowFromMat(const MatrixBase &mat, MatrixIndexT row); +template +void VectorBase::CopyRowFromMat(const MatrixBase &mat, MatrixIndexT row); + +template +template +void VectorBase::CopyRowFromSp(const SpMatrix &sp, MatrixIndexT row) { + KALDI_ASSERT(row < sp.NumRows()); + KALDI_ASSERT(dim_ == sp.NumCols()); + + const OtherReal *sp_data = sp.Data(); + + sp_data += (row*(row+1)) / 2; // takes us to beginning of this row. + MatrixIndexT i; + for (i = 0; i < row; i++) // copy consecutive elements. + data_[i] = static_cast(*(sp_data++)); + for(; i < dim_; ++i, sp_data += i) + data_[i] = static_cast(*sp_data); +} + +template +void VectorBase::CopyRowFromSp(const SpMatrix &mat, MatrixIndexT row); +template +void VectorBase::CopyRowFromSp(const SpMatrix &mat, MatrixIndexT row); +template +void VectorBase::CopyRowFromSp(const SpMatrix &mat, MatrixIndexT row); +template +void VectorBase::CopyRowFromSp(const SpMatrix &mat, MatrixIndexT row); + + +#ifdef HAVE_MKL +template<> +void VectorBase::Pow(const VectorBase &v, float power) { + vsPowx(dim_, data_, power, v.data_); +} +template<> +void VectorBase::Pow(const VectorBase &v, double power) { + vdPowx(dim_, data_, power, v.data_); +} +#else + +// takes elements to a power. Does not check output. +template +void VectorBase::Pow(const VectorBase &v, Real power) { + KALDI_ASSERT(dim_ == v.dim_); + for (MatrixIndexT i = 0; i < dim_; i++) { + data_[i] = pow(v.data_[i], power); + } +} +#endif + +// takes absolute value of the elements to a power. +// Throws exception if could not (but only for power != 1 and power != 2). 
+template +void VectorBase::ApplyPowAbs(Real power, bool include_sign) { + if (power == 1.0) + for (MatrixIndexT i = 0; i < dim_; i++) + data_[i] = (include_sign && data_[i] < 0 ? -1 : 1) * std::abs(data_[i]); + if (power == 2.0) { + for (MatrixIndexT i = 0; i < dim_; i++) + data_[i] = (include_sign && data_[i] < 0 ? -1 : 1) * data_[i] * data_[i]; + } else if (power == 0.5) { + for (MatrixIndexT i = 0; i < dim_; i++) { + data_[i] = (include_sign && data_[i] < 0 ? -1 : 1) * std::sqrt(std::abs(data_[i])); + } + } else if (power < 0.0) { + for (MatrixIndexT i = 0; i < dim_; i++) { + data_[i] = (data_[i] == 0.0 ? 0.0 : pow(std::abs(data_[i]), power)); + data_[i] *= (include_sign && data_[i] < 0 ? -1 : 1); + if (data_[i] == HUGE_VAL) { // HUGE_VAL is what errno returns on error. + KALDI_ERR << "Could not raise element " << i << "to power " + << power << ": returned value = " << data_[i]; + } + } + } else { + for (MatrixIndexT i = 0; i < dim_; i++) { + data_[i] = (include_sign && data_[i] < 0 ? -1 : 1) * pow(std::abs(data_[i]), power); + if (data_[i] == HUGE_VAL) { // HUGE_VAL is what errno returns on error. + KALDI_ERR << "Could not raise element " << i << "to power " + << power << ": returned value = " << data_[i]; + } + } + } +} + +// Computes the p-th norm. Throws exception if could not. 
+template +Real VectorBase::Norm(Real p) const { + KALDI_ASSERT(p >= 0.0); + Real sum = 0.0; + if (p == 0.0) { + for (MatrixIndexT i = 0; i < dim_; i++) + if (data_[i] != 0.0) sum += 1.0; + return sum; + } else if (p == 1.0) { + for (MatrixIndexT i = 0; i < dim_; i++) + sum += std::abs(data_[i]); + return sum; + } else if (p == 2.0) { + for (MatrixIndexT i = 0; i < dim_; i++) + sum += data_[i] * data_[i]; + return std::sqrt(sum); + } else if (p == std::numeric_limits::infinity()){ + for (MatrixIndexT i = 0; i < dim_; i++) + sum = std::max(sum, std::abs(data_[i])); + return sum; + } else { + Real tmp; + bool ok = true; + for (MatrixIndexT i = 0; i < dim_; i++) { + tmp = pow(std::abs(data_[i]), p); + if (tmp == HUGE_VAL) // HUGE_VAL is what pow returns on error. + ok = false; + sum += tmp; + } + tmp = pow(sum, static_cast(1.0/p)); + KALDI_ASSERT(tmp != HUGE_VAL); // should not happen here. + if (ok) { + return tmp; + } else { + Real maximum = this->Max(), minimum = this->Min(), + max_abs = std::max(maximum, -minimum); + KALDI_ASSERT(max_abs > 0); // Or should not have reached here. + Vector tmp(*this); + tmp.Scale(1.0 / max_abs); + return tmp.Norm(p) * max_abs; + } + } +} + +template +bool VectorBase::ApproxEqual(const VectorBase &other, float tol) const { + if (dim_ != other.dim_) KALDI_ERR << "ApproxEqual: size mismatch " + << dim_ << " vs. " << other.dim_; + KALDI_ASSERT(tol >= 0.0); + if (tol != 0.0) { + Vector tmp(*this); + tmp.AddVec(-1.0, other); + return (tmp.Norm(2.0) <= static_cast(tol) * this->Norm(2.0)); + } else { // Test for exact equality. 
+ const Real *data = data_; + const Real *other_data = other.data_; + for (MatrixIndexT dim = dim_, i = 0; i < dim; i++) + if (data[i] != other_data[i]) return false; + return true; + } +} + +template +Real VectorBase::Max() const { + Real ans = - std::numeric_limits::infinity(); + const Real *data = data_; + MatrixIndexT i, dim = dim_; + for (i = 0; i + 4 <= dim; i += 4) { + Real a1 = data[i], a2 = data[i+1], a3 = data[i+2], a4 = data[i+3]; + if (a1 > ans || a2 > ans || a3 > ans || a4 > ans) { + Real b1 = (a1 > a2 ? a1 : a2), b2 = (a3 > a4 ? a3 : a4); + if (b1 > ans) ans = b1; + if (b2 > ans) ans = b2; + } + } + for (; i < dim; i++) + if (data[i] > ans) ans = data[i]; + return ans; +} + +template +Real VectorBase::Max(MatrixIndexT *index_out) const { + if (dim_ == 0) KALDI_ERR << "Empty vector"; + Real ans = - std::numeric_limits::infinity(); + MatrixIndexT index = 0; + const Real *data = data_; + MatrixIndexT i, dim = dim_; + for (i = 0; i + 4 <= dim; i += 4) { + Real a1 = data[i], a2 = data[i+1], a3 = data[i+2], a4 = data[i+3]; + if (a1 > ans || a2 > ans || a3 > ans || a4 > ans) { + if (a1 > ans) { ans = a1; index = i; } + if (a2 > ans) { ans = a2; index = i + 1; } + if (a3 > ans) { ans = a3; index = i + 2; } + if (a4 > ans) { ans = a4; index = i + 3; } + } + } + for (; i < dim; i++) + if (data[i] > ans) { ans = data[i]; index = i; } + *index_out = index; + return ans; +} + +template +Real VectorBase::Min() const { + Real ans = std::numeric_limits::infinity(); + const Real *data = data_; + MatrixIndexT i, dim = dim_; + for (i = 0; i + 4 <= dim; i += 4) { + Real a1 = data[i], a2 = data[i+1], a3 = data[i+2], a4 = data[i+3]; + if (a1 < ans || a2 < ans || a3 < ans || a4 < ans) { + Real b1 = (a1 < a2 ? a1 : a2), b2 = (a3 < a4 ? 
a3 : a4); + if (b1 < ans) ans = b1; + if (b2 < ans) ans = b2; + } + } + for (; i < dim; i++) + if (data[i] < ans) ans = data[i]; + return ans; +} + +template +Real VectorBase::Min(MatrixIndexT *index_out) const { + if (dim_ == 0) KALDI_ERR << "Empty vector"; + Real ans = std::numeric_limits::infinity(); + MatrixIndexT index = 0; + const Real *data = data_; + MatrixIndexT i, dim = dim_; + for (i = 0; i + 4 <= dim; i += 4) { + Real a1 = data[i], a2 = data[i+1], a3 = data[i+2], a4 = data[i+3]; + if (a1 < ans || a2 < ans || a3 < ans || a4 < ans) { + if (a1 < ans) { ans = a1; index = i; } + if (a2 < ans) { ans = a2; index = i + 1; } + if (a3 < ans) { ans = a3; index = i + 2; } + if (a4 < ans) { ans = a4; index = i + 3; } + } + } + for (; i < dim; i++) + if (data[i] < ans) { ans = data[i]; index = i; } + *index_out = index; + return ans; +} + + +template +template +void VectorBase::CopyColFromMat(const MatrixBase &mat, MatrixIndexT col) { + KALDI_ASSERT(col < mat.NumCols()); + KALDI_ASSERT(dim_ == mat.NumRows()); + for (MatrixIndexT i = 0; i < dim_; i++) + data_[i] = mat(i, col); + // can't do this very efficiently so don't really bother. could improve this though. +} +// instantiate the template above. +template +void VectorBase::CopyColFromMat(const MatrixBase &mat, MatrixIndexT col); +template +void VectorBase::CopyColFromMat(const MatrixBase &mat, MatrixIndexT col); +template +void VectorBase::CopyColFromMat(const MatrixBase &mat, MatrixIndexT col); +template +void VectorBase::CopyColFromMat(const MatrixBase &mat, MatrixIndexT col); + +template +void VectorBase::CopyDiagFromMat(const MatrixBase &M) { + KALDI_ASSERT(dim_ == std::min(M.NumRows(), M.NumCols())); + cblas_Xcopy(dim_, M.Data(), M.Stride() + 1, data_, 1); +} + +template +void VectorBase::CopyDiagFromPacked(const PackedMatrix &M) { + KALDI_ASSERT(dim_ == M.NumCols()); + for (MatrixIndexT i = 0; i < dim_; i++) + data_[i] = M(i, i); + // could make this more efficient. 
+} + +template +Real VectorBase::Sum() const { + // Do a dot-product with a size-1 array with a stride of 0 to + // implement sum. This allows us to access SIMD operations in a + // cross-platform way via your BLAS library. + Real one(1); + return cblas_Xdot(dim_, data_, 1, &one, 0); +} + +template +Real VectorBase::SumLog() const { + double sum_log = 0.0; + double prod = 1.0; + for (MatrixIndexT i = 0; i < dim_; i++) { + prod *= data_[i]; + // Possible future work (arnab): change these magic values to pre-defined + // constants + if (prod < 1.0e-10 || prod > 1.0e+10) { + sum_log += Log(prod); + prod = 1.0; + } + } + if (prod != 1.0) sum_log += Log(prod); + return sum_log; +} + +template +void VectorBase::AddRowSumMat(Real alpha, const MatrixBase &M, Real beta) { + KALDI_ASSERT(dim_ == M.NumCols()); + MatrixIndexT num_rows = M.NumRows(), stride = M.Stride(), dim = dim_; + Real *data = data_; + + // implement the function according to a dimension cutoff for computation efficiency + if (num_rows <= 64) { + cblas_Xscal(dim, beta, data, 1); + const Real *m_data = M.Data(); + for (MatrixIndexT i = 0; i < num_rows; i++, m_data += stride) + cblas_Xaxpy(dim, alpha, m_data, 1, data, 1); + + } else { + Vector ones(M.NumRows()); + ones.Set(1.0); + this->AddMatVec(alpha, M, kTrans, ones, beta); + } +} + +template +void VectorBase::AddColSumMat(Real alpha, const MatrixBase &M, Real beta) { + KALDI_ASSERT(dim_ == M.NumRows()); + MatrixIndexT num_cols = M.NumCols(); + + // implement the function according to a dimension cutoff for computation efficiency + if (num_cols <= 64) { + for (MatrixIndexT i = 0; i < dim_; i++) { + double sum = 0.0; + const Real *src = M.RowData(i); + for (MatrixIndexT j = 0; j < num_cols; j++) + sum += src[j]; + data_[i] = alpha * sum + beta * data_[i]; + } + } else { + Vector ones(M.NumCols()); + ones.Set(1.0); + this->AddMatVec(alpha, M, kNoTrans, ones, beta); + } +} + +template +Real VectorBase::LogSumExp(Real prune) const { + Real sum; + if 
(sizeof(sum) == 8) sum = kLogZeroDouble; + else sum = kLogZeroFloat; + Real max_elem = Max(), cutoff; + if (sizeof(Real) == 4) cutoff = max_elem + kMinLogDiffFloat; + else cutoff = max_elem + kMinLogDiffDouble; + if (prune > 0.0 && max_elem - prune > cutoff) // explicit pruning... + cutoff = max_elem - prune; + + double sum_relto_max_elem = 0.0; + + for (MatrixIndexT i = 0; i < dim_; i++) { + BaseFloat f = data_[i]; + if (f >= cutoff) + sum_relto_max_elem += Exp(f - max_elem); + } + return max_elem + Log(sum_relto_max_elem); +} + +template +void VectorBase::InvertElements() { + for (MatrixIndexT i = 0; i < dim_; i++) { + data_[i] = static_cast(1 / data_[i]); + } +} + +template +void VectorBase::ApplyLog() { + for (MatrixIndexT i = 0; i < dim_; i++) { + if (data_[i] < 0.0) + KALDI_ERR << "Trying to take log of a negative number."; + data_[i] = Log(data_[i]); + } +} + +template +void VectorBase::ApplyLogAndCopy(const VectorBase &v) { + KALDI_ASSERT(dim_ == v.Dim()); + for (MatrixIndexT i = 0; i < dim_; i++) { + data_[i] = Log(v(i)); + } +} + +template +void VectorBase::ApplyExp() { + for (MatrixIndexT i = 0; i < dim_; i++) { + data_[i] = Exp(data_[i]); + } +} + +template +void VectorBase::ApplyAbs() { + for (MatrixIndexT i = 0; i < dim_; i++) { data_[i] = std::abs(data_[i]); } +} + +template +void VectorBase::Floor(const VectorBase &v, Real floor_val, MatrixIndexT *floored_count) { + KALDI_ASSERT(dim_ == v.dim_); + if (floored_count == nullptr) { + for (MatrixIndexT i = 0; i < dim_; i++) { + data_[i] = std::max(v.data_[i], floor_val); + } + } else { + MatrixIndexT num_floored = 0; + for (MatrixIndexT i = 0; i < dim_; i++) { + if (v.data_[i] < floor_val) { + data_[i] = floor_val; + num_floored++; + } else { + data_[i] = v.data_[i]; + } + } + *floored_count = num_floored; + } +} + +template +void VectorBase::Ceiling(const VectorBase &v, Real ceil_val, MatrixIndexT *ceiled_count) { + KALDI_ASSERT(dim_ == v.dim_); + if (ceiled_count == nullptr) { + for (MatrixIndexT i = 
0; i < dim_; i++) { + data_[i] = std::min(v.data_[i], ceil_val); + } + } else { + MatrixIndexT num_changed = 0; + for (MatrixIndexT i = 0; i < dim_; i++) { + if (v.data_[i] > ceil_val) { + data_[i] = ceil_val; + num_changed++; + } else { + data_[i] = v.data_[i]; + } + } + *ceiled_count = num_changed; + } +} + +template +MatrixIndexT VectorBase::ApplyFloor(const VectorBase &floor_vec) { + KALDI_ASSERT(floor_vec.Dim() == dim_); + MatrixIndexT num_floored = 0; + for (MatrixIndexT i = 0; i < dim_; i++) { + if (data_[i] < floor_vec(i)) { + data_[i] = floor_vec(i); + num_floored++; + } + } + return num_floored; +} + +template +Real VectorBase::ApplySoftMax() { + Real max = this->Max(), sum = 0.0; + for (MatrixIndexT i = 0; i < dim_; i++) { + sum += (data_[i] = Exp(data_[i] - max)); + } + this->Scale(1.0 / sum); + return max + Log(sum); +} + +template +Real VectorBase::ApplyLogSoftMax() { + Real max = this->Max(), sum = 0.0; + for (MatrixIndexT i = 0; i < dim_; i++) { + sum += Exp((data_[i] -= max)); + } + sum = Log(sum); + this->Add(-1.0 * sum); + return max + sum; +} + +#ifdef HAVE_MKL +template<> +void VectorBase::Tanh(const VectorBase &src) { + KALDI_ASSERT(dim_ == src.dim_); + vsTanh(dim_, src.data_, data_); +} +template<> +void VectorBase::Tanh(const VectorBase &src) { + KALDI_ASSERT(dim_ == src.dim_); + vdTanh(dim_, src.data_, data_); +} +#else +template +void VectorBase::Tanh(const VectorBase &src) { + KALDI_ASSERT(dim_ == src.dim_); + for (MatrixIndexT i = 0; i < dim_; i++) { + Real x = src.data_[i]; + if (x > 0.0) { + Real inv_expx = Exp(-x); + x = -1.0 + 2.0 / (1.0 + inv_expx * inv_expx); + } else { + Real expx = Exp(x); + x = 1.0 - 2.0 / (1.0 + expx * expx); + } + data_[i] = x; + } +} +#endif + +#ifdef HAVE_MKL +// Implementing sigmoid based on tanh. 
+template<> +void VectorBase::Sigmoid(const VectorBase &src) { + KALDI_ASSERT(dim_ == src.dim_); + this->CopyFromVec(src); + this->Scale(0.5); + vsTanh(dim_, data_, data_); + this->Add(1.0); + this->Scale(0.5); +} +template<> +void VectorBase::Sigmoid(const VectorBase &src) { + KALDI_ASSERT(dim_ == src.dim_); + this->CopyFromVec(src); + this->Scale(0.5); + vdTanh(dim_, data_, data_); + this->Add(1.0); + this->Scale(0.5); +} +#else +template +void VectorBase::Sigmoid(const VectorBase &src) { + KALDI_ASSERT(dim_ == src.dim_); + for (MatrixIndexT i = 0; i < dim_; i++) { + Real x = src.data_[i]; + // We aim to avoid floating-point overflow here. + if (x > 0.0) { + x = 1.0 / (1.0 + Exp(-x)); + } else { + Real ex = Exp(x); + x = ex / (ex + 1.0); + } + data_[i] = x; + } +} +#endif + + +template +void VectorBase::Add(Real c) { + for (MatrixIndexT i = 0; i < dim_; i++) { + data_[i] += c; + } +} + +template +void VectorBase::Scale(Real alpha) { + cblas_Xscal(dim_, alpha, data_, 1); +} + +template +void VectorBase::MulElements(const VectorBase &v) { + KALDI_ASSERT(dim_ == v.dim_); + for (MatrixIndexT i = 0; i < dim_; i++) { + data_[i] *= v.data_[i]; + } +} + +template // Set each element to y = (x == orig ? changed : x). +void VectorBase::ReplaceValue(Real orig, Real changed) { + Real *data = data_; + for (MatrixIndexT i = 0; i < dim_; i++) + if (data[i] == orig) data[i] = changed; +} + + +template +template +void VectorBase::MulElements(const VectorBase &v) { + KALDI_ASSERT(dim_ == v.Dim()); + const OtherReal *other_ptr = v.Data(); + for (MatrixIndexT i = 0; i < dim_; i++) { + data_[i] *= other_ptr[i]; + } +} +// instantiate template. 
+template +void VectorBase::MulElements(const VectorBase &v); +template +void VectorBase::MulElements(const VectorBase &v); + + +template +void VectorBase::AddVecVec(Real alpha, const VectorBase &v, + const VectorBase &r, Real beta) { + KALDI_ASSERT(v.data_ != this->data_ && r.data_ != this->data_); + // We pretend that v is a band-diagonal matrix. + KALDI_ASSERT(dim_ == v.dim_ && dim_ == r.dim_); + cblas_Xgbmv(kNoTrans, dim_, dim_, 0, 0, alpha, v.data_, 1, + r.data_, 1, beta, this->data_, 1); +} + + +template +void VectorBase::DivElements(const VectorBase &v) { + KALDI_ASSERT(dim_ == v.dim_); + for (MatrixIndexT i = 0; i < dim_; i++) { + data_[i] /= v.data_[i]; + } +} + +template +template +void VectorBase::DivElements(const VectorBase &v) { + KALDI_ASSERT(dim_ == v.Dim()); + const OtherReal *other_ptr = v.Data(); + for (MatrixIndexT i = 0; i < dim_; i++) { + data_[i] /= other_ptr[i]; + } +} +// instantiate template. +template +void VectorBase::DivElements(const VectorBase &v); +template +void VectorBase::DivElements(const VectorBase &v); + +template +void VectorBase::AddVecDivVec(Real alpha, const VectorBase &v, + const VectorBase &rr, Real beta) { + KALDI_ASSERT((dim_ == v.dim_ && dim_ == rr.dim_)); + for (MatrixIndexT i = 0; i < dim_; i++) { + data_[i] = alpha * v.data_[i]/rr.data_[i] + beta * data_[i] ; + } +} + +template +template +void VectorBase::AddVec(const Real alpha, const VectorBase &v) { + KALDI_ASSERT(dim_ == v.dim_); + // remove __restrict__ if it causes compilation problems. 
+ Real *__restrict__ data = data_; + OtherReal *__restrict__ other_data = v.data_; + MatrixIndexT dim = dim_; + if (alpha != 1.0) + for (MatrixIndexT i = 0; i < dim; i++) + data[i] += alpha * other_data[i]; + else + for (MatrixIndexT i = 0; i < dim; i++) + data[i] += other_data[i]; +} + +template +void VectorBase::AddVec(const float alpha, const VectorBase &v); +template +void VectorBase::AddVec(const double alpha, const VectorBase &v); + +template +template +void VectorBase::AddVec2(const Real alpha, const VectorBase &v) { + KALDI_ASSERT(dim_ == v.dim_); + // remove __restrict__ if it causes compilation problems. + Real *__restrict__ data = data_; + OtherReal *__restrict__ other_data = v.data_; + MatrixIndexT dim = dim_; + if (alpha != 1.0) + for (MatrixIndexT i = 0; i < dim; i++) + data[i] += alpha * other_data[i] * other_data[i]; + else + for (MatrixIndexT i = 0; i < dim; i++) + data[i] += other_data[i] * other_data[i]; +} + +template +void VectorBase::AddVec2(const float alpha, const VectorBase &v); +template +void VectorBase::AddVec2(const double alpha, const VectorBase &v); + + +template +void VectorBase::Read(std::istream &is, bool binary, bool add) { + if (add) { + Vector tmp(Dim()); + tmp.Read(is, binary, false); // read without adding. + if (this->Dim() != tmp.Dim()) { + KALDI_ERR << "VectorBase::Read, size mismatch " << this->Dim()<<" vs. "<AddVec(1.0, tmp); + return; + } // now assume add == false. + + // In order to avoid rewriting this, we just declare a Vector and + // use it to read the data, then copy. + Vector tmp; + tmp.Read(is, binary, false); + if (tmp.Dim() != Dim()) + KALDI_ERR << "VectorBase::Read, size mismatch " + << Dim() << " vs. " << tmp.Dim(); + CopyFromVec(tmp); +} + + +template +void Vector::Read(std::istream &is, bool binary, bool add) { + if (add) { + Vector tmp(this->Dim()); + tmp.Read(is, binary, false); // read without adding. 
+ if (this->Dim() == 0) this->Resize(tmp.Dim()); + if (this->Dim() != tmp.Dim()) { + KALDI_ERR << "Vector::Read, adding but dimensions mismatch " + << this->Dim() << " vs. " << tmp.Dim(); + } + this->AddVec(1.0, tmp); + return; + } // now assume add == false. + + std::ostringstream specific_error; + MatrixIndexT pos_at_start = is.tellg(); + + if (binary) { + int peekval = Peek(is, binary); + const char *my_token = (sizeof(Real) == 4 ? "FV" : "DV"); + char other_token_start = (sizeof(Real) == 4 ? 'D' : 'F'); + if (peekval == other_token_start) { // need to instantiate the other type to read it. + typedef typename OtherReal::Real OtherType; // if Real == float, OtherType == double, and vice versa. + Vector other(this->Dim()); + other.Read(is, binary, false); // add is false at this point. + if (this->Dim() != other.Dim()) this->Resize(other.Dim()); + this->CopyFromVec(other); + return; + } + std::string token; + ReadToken(is, binary, &token); + if (token != my_token) { + if (token.length() > 20) token = token.substr(0, 17) + "..."; + specific_error << ": Expected token " << my_token << ", got " << token; + goto bad; + } + int32 size; + ReadBasicType(is, binary, &size); // throws on error. + if ((MatrixIndexT)size != this->Dim()) this->Resize(size); + if (size > 0) + is.read(reinterpret_cast(this->data_), sizeof(Real)*size); + if (is.fail()) { + specific_error << "Error reading vector data (binary mode); truncated " + "stream? (size = " << size << ")"; + goto bad; + } + return; + } else { // Text mode reading; format is " [ 1.1 2.0 3.4 ]\n" + std::string s; + is >> s; + // if ((s.compare("DV") == 0) || (s.compare("FV") == 0)) { // Back compatibility. + // is >> s; // get dimension + // is >> s; // get "[" + // } + if (is.fail()) { specific_error << "EOF while trying to read vector."; goto bad; } + if (s.compare("[]") == 0) { Resize(0); return; } // tolerate this variant. 
+ if (s.compare("[")) { + if (s.length() > 20) s = s.substr(0, 17) + "..."; + specific_error << "Expected \"[\" but got " << s; + goto bad; + } + std::vector data; + while (1) { + int i = is.peek(); + if (i == '-' || (i >= '0' && i <= '9')) { // common cases first. + Real r; + is >> r; + if (is.fail()) { specific_error << "Failed to read number."; goto bad; } + if (! std::isspace(is.peek()) && is.peek() != ']') { + specific_error << "Expected whitespace after number."; goto bad; + } + data.push_back(r); + // But don't eat whitespace... we want to check that it's not newlines + // which would be valid only for a matrix. + } else if (i == ' ' || i == '\t') { + is.get(); + } else if (i == ']') { + is.get(); // eat the ']' + this->Resize(data.size()); + for (size_t j = 0; j < data.size(); j++) + this->data_[j] = data[j]; + i = is.peek(); + if (static_cast(i) == '\r') { + is.get(); + is.get(); // get \r\n (must eat what we wrote) + } else if (static_cast(i) == '\n') { is.get(); } // get \n (must eat what we wrote) + if (is.fail()) { + KALDI_WARN << "After end of vector data, read error."; + // we got the data we needed, so just warn for this error. + } + return; // success. + } else if (i == -1) { + specific_error << "EOF while reading vector data."; + goto bad; + } else if (i == '\n' || i == '\r') { + specific_error << "Newline found while reading vector (maybe it's a matrix?)"; + goto bad; + } else { + is >> s; // read string. 
+ if (!KALDI_STRCASECMP(s.c_str(), "inf") || + !KALDI_STRCASECMP(s.c_str(), "infinity")) { + data.push_back(std::numeric_limits::infinity()); + KALDI_WARN << "Reading infinite value into vector."; + } else if (!KALDI_STRCASECMP(s.c_str(), "nan")) { + data.push_back(std::numeric_limits::quiet_NaN()); + KALDI_WARN << "Reading NaN value into vector."; + } else { + if (s.length() > 20) s = s.substr(0, 17) + "..."; + specific_error << "Expecting numeric vector data, got " << s; + goto bad; + } + } + } + } + // we never reach this line (the while loop returns directly). +bad: + KALDI_ERR << "Failed to read vector from stream. " << specific_error.str() + << " File position at start is " + << pos_at_start<<", currently "< +void VectorBase::Write(std::ostream & os, bool binary) const { + if (!os.good()) { + KALDI_ERR << "Failed to write vector to stream: stream not good"; + } + if (binary) { + std::string my_token = (sizeof(Real) == 4 ? "FV" : "DV"); + WriteToken(os, binary, my_token); + + int32 size = Dim(); // make the size 32-bit on disk. + KALDI_ASSERT(Dim() == (MatrixIndexT) size); + WriteBasicType(os, binary, size); + os.write(reinterpret_cast(Data()), sizeof(Real) * size); + } else { + os << " [ "; + for (MatrixIndexT i = 0; i < Dim(); i++) + os << (*this)(i) << " "; + os << "]\n"; + } + if (!os.good()) + KALDI_ERR << "Failed to write vector to stream"; +} + + +template +void VectorBase::AddVec2(const Real alpha, const VectorBase &v) { + KALDI_ASSERT(dim_ == v.dim_); + for (MatrixIndexT i = 0; i < dim_; i++) + data_[i] += alpha * v.data_[i] * v.data_[i]; +} + +// this <-- beta*this + alpha*M*v. 
+template +void VectorBase::AddTpVec(const Real alpha, const TpMatrix &M, + const MatrixTransposeType trans, + const VectorBase &v, + const Real beta) { + KALDI_ASSERT(dim_ == v.dim_ && dim_ == M.NumRows()); + if (beta == 0.0) { + if (&v != this) CopyFromVec(v); + MulTp(M, trans); + if (alpha != 1.0) Scale(alpha); + } else { + Vector tmp(v); + tmp.MulTp(M, trans); + if (beta != 1.0) Scale(beta); // *this <-- beta * *this + AddVec(alpha, tmp); // *this += alpha * M * v + } +} + +template +Real VecMatVec(const VectorBase &v1, const MatrixBase &M, + const VectorBase &v2) { + KALDI_ASSERT(v1.Dim() == M.NumRows() && v2.Dim() == M.NumCols()); + Vector vtmp(M.NumRows()); + vtmp.AddMatVec(1.0, M, kNoTrans, v2, 0.0); + return VecVec(v1, vtmp); +} + +template +float VecMatVec(const VectorBase &v1, const MatrixBase &M, + const VectorBase &v2); +template +double VecMatVec(const VectorBase &v1, const MatrixBase &M, + const VectorBase &v2); + +template +void Vector::Swap(Vector *other) { + std::swap(this->data_, other->data_); + std::swap(this->dim_, other->dim_); +} + + +template +void VectorBase::AddDiagMat2( + Real alpha, const MatrixBase &M, + MatrixTransposeType trans, Real beta) { + if (trans == kNoTrans) { + KALDI_ASSERT(this->dim_ == M.NumRows()); + MatrixIndexT rows = this->dim_, cols = M.NumCols(), + mat_stride = M.Stride(); + Real *data = this->data_; + const Real *mat_data = M.Data(); + for (MatrixIndexT i = 0; i < rows; i++, mat_data += mat_stride, data++) + *data = beta * *data + alpha * cblas_Xdot(cols,mat_data,1,mat_data,1); + } else { + KALDI_ASSERT(this->dim_ == M.NumCols()); + MatrixIndexT rows = M.NumRows(), cols = this->dim_, + mat_stride = M.Stride(); + Real *data = this->data_; + const Real *mat_data = M.Data(); + for (MatrixIndexT i = 0; i < cols; i++, mat_data++, data++) + *data = beta * *data + alpha * cblas_Xdot(rows, mat_data, mat_stride, + mat_data, mat_stride); + } +} + +template +void VectorBase::AddDiagMatMat( + Real alpha, + const MatrixBase &M, 
MatrixTransposeType transM, + const MatrixBase &N, MatrixTransposeType transN, + Real beta) { + MatrixIndexT dim = this->dim_, + M_col_dim = (transM == kTrans ? M.NumRows() : M.NumCols()), + N_row_dim = (transN == kTrans ? N.NumCols() : N.NumRows()); + KALDI_ASSERT(M_col_dim == N_row_dim); // this is the dimension we sum over + MatrixIndexT M_row_stride = M.Stride(), M_col_stride = 1; + if (transM == kTrans) std::swap(M_row_stride, M_col_stride); + MatrixIndexT N_row_stride = N.Stride(), N_col_stride = 1; + if (transN == kTrans) std::swap(N_row_stride, N_col_stride); + + Real *data = this->data_; + const Real *Mdata = M.Data(), *Ndata = N.Data(); + for (MatrixIndexT i = 0; i < dim; i++, Mdata += M_row_stride, Ndata += N_col_stride, data++) { + *data = beta * *data + alpha * cblas_Xdot(M_col_dim, Mdata, M_col_stride, Ndata, N_row_stride); + } +} + + +template class Vector; +template class Vector; +template class VectorBase; +template class VectorBase; + +} // namespace kaldi diff --git a/speechx/speechx/kaldi/matrix/kaldi-vector.h b/speechx/speechx/kaldi/matrix/kaldi-vector.h new file mode 100644 index 00000000..2a032354 --- /dev/null +++ b/speechx/speechx/kaldi/matrix/kaldi-vector.h @@ -0,0 +1,612 @@ +// matrix/kaldi-vector.h + +// Copyright 2009-2012 Ondrej Glembek; Microsoft Corporation; Lukas Burget; +// Saarland University (Author: Arnab Ghoshal); +// Ariya Rastrow; Petr Schwarz; Yanmin Qian; +// Karel Vesely; Go Vivace Inc.; Arnab Ghoshal +// Wei Shi; +// 2015 Guoguo Chen +// 2017 Daniel Galvez +// 2019 Yiwen Shao + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_KALDI_VECTOR_H_ +#define KALDI_MATRIX_KALDI_VECTOR_H_ 1 + +#include "matrix/matrix-common.h" + +namespace kaldi { + +/// \addtogroup matrix_group +/// @{ + +/// Provides a vector abstraction class. +/// This class provides a way to work with vectors in kaldi. +/// It encapsulates basic operations and memory optimizations. +template +class VectorBase { + public: + /// Set vector to all zeros. + void SetZero(); + + /// Returns true if matrix is all zeros. + bool IsZero(Real cutoff = 1.0e-06) const; // replace magic number + + /// Set all members of a vector to a specified value. + void Set(Real f); + + /// Set vector to random normally-distributed noise. + void SetRandn(); + + /// Sets to numbers uniformly distributed on (0,1) + void SetRandUniform(); + + /// This function returns a random index into this vector, + /// chosen with probability proportional to the corresponding + /// element. Requires that this->Min() >= 0 and this->Sum() > 0. + MatrixIndexT RandCategorical() const; + + /// Returns the dimension of the vector. + inline MatrixIndexT Dim() const { return dim_; } + + /// Returns the size in memory of the vector, in bytes. + inline MatrixIndexT SizeInBytes() const { return (dim_*sizeof(Real)); } + + /// Returns a pointer to the start of the vector's data. + inline Real* Data() { return data_; } + + /// Returns a pointer to the start of the vector's data (const). + inline const Real* Data() const { return data_; } + + /// Indexing operator (const). 
+ inline Real operator() (MatrixIndexT i) const { + KALDI_PARANOID_ASSERT(static_cast(i) < + static_cast(dim_)); + return *(data_ + i); + } + + /// Indexing operator (non-const). + inline Real & operator() (MatrixIndexT i) { + KALDI_PARANOID_ASSERT(static_cast(i) < + static_cast(dim_)); + return *(data_ + i); + } + + /** @brief Returns a sub-vector of a vector (a range of elements). + * @param o [in] Origin, 0 < o < Dim() + * @param l [in] Length 0 < l < Dim()-o + * @return A SubVector object that aliases the data of the Vector object. + * See @c SubVector class for details */ + SubVector Range(const MatrixIndexT o, const MatrixIndexT l) { + return SubVector(*this, o, l); + } + + /** @brief Returns a const sub-vector of a vector (a range of elements). + * @param o [in] Origin, 0 < o < Dim() + * @param l [in] Length 0 < l < Dim()-o + * @return A SubVector object that aliases the data of the Vector object. + * See @c SubVector class for details */ + const SubVector Range(const MatrixIndexT o, + const MatrixIndexT l) const { + return SubVector(*this, o, l); + } + + /// Copy data from another vector (must match own size). + void CopyFromVec(const VectorBase &v); + + /// Copy data from a SpMatrix or TpMatrix (must match own size). + template + void CopyFromPacked(const PackedMatrix &M); + + /// Copy data from another vector of different type (double vs. float) + template + void CopyFromVec(const VectorBase &v); + + /// Copy from CuVector. This is defined in ../cudamatrix/cu-vector.h + template + void CopyFromVec(const CuVectorBase &v); + + /// Applies floor to all elements. Returns number of elements + /// floored in floored_count if it is non-null. + void Floor(const VectorBase &v, Real floor_val, MatrixIndexT *floored_count = nullptr); + + /// Applies ceiling to all elements. Returns number of elements + /// changed in ceiled_count if it is non-null. 
+ void Ceiling(const VectorBase &v, Real ceil_val, MatrixIndexT *ceiled_count = nullptr); + + void Pow(const VectorBase &v, Real power); + + /// Apply natural log to all elements. Throw if any element of + /// the vector is negative (but doesn't complain about zero; the + /// log will be -infinity + void ApplyLog(); + + /// Apply natural log to another vector and put result in *this. + void ApplyLogAndCopy(const VectorBase &v); + + /// Apply exponential to each value in vector. + void ApplyExp(); + + /// Take absolute value of each of the elements + void ApplyAbs(); + + /// Applies floor to all elements. Returns number of elements + /// floored in floored_count if it is non-null. + inline void ApplyFloor(Real floor_val, MatrixIndexT *floored_count = nullptr) { + this->Floor(*this, floor_val, floored_count); + }; + + /// Applies ceiling to all elements. Returns number of elements + /// changed in ceiled_count if it is non-null. + inline void ApplyCeiling(Real ceil_val, MatrixIndexT *ceiled_count = nullptr) { + this->Ceiling(*this, ceil_val, ceiled_count); + }; + + /// Applies floor to all elements. Returns number of elements floored. + MatrixIndexT ApplyFloor(const VectorBase &floor_vec); + + /// Apply soft-max to vector and return normalizer (log sum of exponentials). + /// This is the same as: \f$ x(i) = exp(x(i)) / \sum_i exp(x(i)) \f$ + Real ApplySoftMax(); + + /// Applies log soft-max to vector and returns normalizer (log sum of + /// exponentials). + /// This is the same as: \f$ x(i) = x(i) - log(\sum_i exp(x(i))) \f$ + Real ApplyLogSoftMax(); + + /// Sets each element of *this to the tanh of the corresponding element of "src". + void Tanh(const VectorBase &src); + + /// Sets each element of *this to the sigmoid function of the corresponding + /// element of "src". + void Sigmoid(const VectorBase &src); + + /// Take all elements of vector to a power. 
+ inline void ApplyPow(Real power) { + this->Pow(*this, power); + }; + + /// Take the absolute value of all elements of a vector to a power. + /// Include the sign of the input element if include_sign == true. + /// If power is negative and the input value is zero, the output is set zero. + void ApplyPowAbs(Real power, bool include_sign=false); + + /// Compute the p-th norm of the vector. + Real Norm(Real p) const; + + /// Returns true if ((*this)-other).Norm(2.0) <= tol * (*this).Norm(2.0). + bool ApproxEqual(const VectorBase &other, float tol = 0.01) const; + + /// Invert all elements. + void InvertElements(); + + /// Add vector : *this = *this + alpha * rv (with casting between floats and + /// doubles) + template + void AddVec(const Real alpha, const VectorBase &v); + + /// Add vector : *this = *this + alpha * rv^2 [element-wise squaring]. + void AddVec2(const Real alpha, const VectorBase &v); + + /// Add vector : *this = *this + alpha * rv^2 [element-wise squaring], + /// with casting between floats and doubles. + template + void AddVec2(const Real alpha, const VectorBase &v); + + /// Add matrix times vector : this <-- beta*this + alpha*M*v. + /// Calls BLAS GEMV. + void AddMatVec(const Real alpha, const MatrixBase &M, + const MatrixTransposeType trans, const VectorBase &v, + const Real beta); // **beta previously defaulted to 0.0** + + /// This is as AddMatVec, except optimized for where v contains a lot + /// of zeros. + void AddMatSvec(const Real alpha, const MatrixBase &M, + const MatrixTransposeType trans, const VectorBase &v, + const Real beta); // **beta previously defaulted to 0.0** + + + /// Add symmetric positive definite matrix times vector: + /// this <-- beta*this + alpha*M*v. Calls BLAS SPMV. + void AddSpVec(const Real alpha, const SpMatrix &M, + const VectorBase &v, const Real beta); // **beta previously defaulted to 0.0** + + /// Add triangular matrix times vector: this <-- beta*this + alpha*M*v. + /// Works even if rv == *this. 
+ void AddTpVec(const Real alpha, const TpMatrix &M, + const MatrixTransposeType trans, const VectorBase &v, + const Real beta); // **beta previously defaulted to 0.0** + + /// Set each element to y = (x == orig ? changed : x). + void ReplaceValue(Real orig, Real changed); + + /// Multiply element-by-element by another vector. + void MulElements(const VectorBase &v); + /// Multiply element-by-element by another vector of different type. + template + void MulElements(const VectorBase &v); + + /// Divide element-by-element by a vector. + void DivElements(const VectorBase &v); + /// Divide element-by-element by a vector of different type. + template + void DivElements(const VectorBase &v); + + /// Add a constant to each element of a vector. + void Add(Real c); + + /// Add element-by-element product of vectors: + // this <-- alpha * v .* r + beta*this . + void AddVecVec(Real alpha, const VectorBase &v, + const VectorBase &r, Real beta); + + /// Add element-by-element quotient of two vectors. + /// this <---- alpha*v/r + beta*this + void AddVecDivVec(Real alpha, const VectorBase &v, + const VectorBase &r, Real beta); + + /// Multiplies all elements by this constant. + void Scale(Real alpha); + + /// Multiplies this vector by lower-triangular matrix: *this <-- *this *M + void MulTp(const TpMatrix &M, const MatrixTransposeType trans); + + /// If trans == kNoTrans, solves M x = b, where b is the value of *this at input + /// and x is the value of *this at output. + /// If trans == kTrans, solves M' x = b. + /// Does not test for M being singular or near-singular, so test it before + /// calling this routine. 
+ void Solve(const TpMatrix &M, const MatrixTransposeType trans); + + /// Performs a row stack of the matrix M + void CopyRowsFromMat(const MatrixBase &M); + template + void CopyRowsFromMat(const MatrixBase &M); + + /// The following is implemented in ../cudamatrix/cu-matrix.cc + void CopyRowsFromMat(const CuMatrixBase &M); + + /// Performs a column stack of the matrix M + void CopyColsFromMat(const MatrixBase &M); + + /// Extracts a row of the matrix M. Could also do this with + /// this->Copy(M[row]). + void CopyRowFromMat(const MatrixBase &M, MatrixIndexT row); + /// Extracts a row of the matrix M with type conversion. + template + void CopyRowFromMat(const MatrixBase &M, MatrixIndexT row); + + /// Extracts a row of the symmetric matrix S. + template + void CopyRowFromSp(const SpMatrix &S, MatrixIndexT row); + + /// Extracts a column of the matrix M. + template + void CopyColFromMat(const MatrixBase &M , MatrixIndexT col); + + /// Extracts the diagonal of the matrix M. + void CopyDiagFromMat(const MatrixBase &M); + + /// Extracts the diagonal of a packed matrix M; works for Sp or Tp. + void CopyDiagFromPacked(const PackedMatrix &M); + + + /// Extracts the diagonal of a symmetric matrix. + inline void CopyDiagFromSp(const SpMatrix &M) { CopyDiagFromPacked(M); } + + /// Extracts the diagonal of a triangular matrix. + inline void CopyDiagFromTp(const TpMatrix &M) { CopyDiagFromPacked(M); } + + /// Returns the maximum value of any element, or -infinity for the empty vector. + Real Max() const; + + /// Returns the maximum value of any element, and the associated index. + /// Error if vector is empty. + Real Max(MatrixIndexT *index) const; + + /// Returns the minimum value of any element, or +infinity for the empty vector. + Real Min() const; + + /// Returns the minimum value of any element, and the associated index. + /// Error if vector is empty. 
+ Real Min(MatrixIndexT *index) const; + + /// Returns sum of the elements + Real Sum() const; + + /// Returns sum of the logs of the elements. More efficient than + /// just taking log of each. Will return NaN if any elements are + /// negative. + Real SumLog() const; + + /// Does *this = alpha * (sum of rows of M) + beta * *this. + void AddRowSumMat(Real alpha, const MatrixBase &M, Real beta = 1.0); + + /// Does *this = alpha * (sum of columns of M) + beta * *this. + void AddColSumMat(Real alpha, const MatrixBase &M, Real beta = 1.0); + + /// Add the diagonal of a matrix times itself: + /// *this = diag(M M^T) + beta * *this (if trans == kNoTrans), or + /// *this = diag(M^T M) + beta * *this (if trans == kTrans). + void AddDiagMat2(Real alpha, const MatrixBase &M, + MatrixTransposeType trans = kNoTrans, Real beta = 1.0); + + /// Add the diagonal of a matrix product: *this = diag(M N), assuming the + /// "trans" arguments are both kNoTrans; for transpose arguments, it behaves + /// as you would expect. + void AddDiagMatMat(Real alpha, const MatrixBase &M, MatrixTransposeType transM, + const MatrixBase &N, MatrixTransposeType transN, + Real beta = 1.0); + + /// Returns log(sum(exp())) without exp overflow + /// If prune > 0.0, ignores terms less than the max - prune. + /// [Note: in future, if prune = 0.0, it will take the max. + /// For now, use -1 if you don't want it to prune.] + Real LogSumExp(Real prune = -1.0) const; + + /// Reads from C++ stream (option to add to existing contents). + /// Throws exception on failure + void Read(std::istream &in, bool binary, bool add = false); + + /// Writes to C++ stream (option to write in binary). + void Write(std::ostream &Out, bool binary) const; + + friend class VectorBase; + friend class VectorBase; + friend class CuVectorBase; + friend class CuVector; + protected: + /// Destructor; does not deallocate memory, this is handled by child classes. 
+ /// This destructor is protected so this object can only be + /// deleted via a child. + ~VectorBase() {} + + /// Empty initializer, corresponds to vector of zero size. + explicit VectorBase(): data_(NULL), dim_(0) { + KALDI_ASSERT_IS_FLOATING_TYPE(Real); + } + +// Took this out since it is not currently used, and it is possible to create +// objects where the allocated memory is not the same size as dim_ : Arnab +// /// Initializer from a pointer and a size; keeps the pointer internally +// /// (ownership or non-ownership depends on the child class). +// explicit VectorBase(Real* data, MatrixIndexT dim) +// : data_(data), dim_(dim) {} + + // Arnab : made this protected since it is unsafe too. + /// Load data into the vector: sz must match own size. + void CopyFromPtr(const Real* Data, MatrixIndexT sz); + + /// data memory area + Real* data_; + /// dimension of vector + MatrixIndexT dim_; + KALDI_DISALLOW_COPY_AND_ASSIGN(VectorBase); +}; // class VectorBase + +/** @brief A class representing a vector. + * + * This class provides a way to work with vectors in kaldi. + * It encapsulates basic operations and memory optimizations. */ +template +class Vector: public VectorBase { + public: + /// Constructor that takes no arguments. Initializes to empty. + Vector(): VectorBase() {} + + /// Constructor with specific size. Sets to all-zero by default + /// if set_zero == false, memory contents are undefined. + explicit Vector(const MatrixIndexT s, + MatrixResizeType resize_type = kSetZero) + : VectorBase() { Resize(s, resize_type); } + + /// Copy constructor from CUDA vector + /// This is defined in ../cudamatrix/cu-vector.h + template + explicit Vector(const CuVectorBase &cu); + + /// Copy constructor. The need for this is controversial. + Vector(const Vector &v) : VectorBase() { // (cannot be explicit) + Resize(v.Dim(), kUndefined); + this->CopyFromVec(v); + } + + /// Copy-constructor from base-class, needed to copy from SubVector. 
+ explicit Vector(const VectorBase &v) : VectorBase() { + Resize(v.Dim(), kUndefined); + this->CopyFromVec(v); + } + + /// Type conversion constructor. + template + explicit Vector(const VectorBase &v): VectorBase() { + Resize(v.Dim(), kUndefined); + this->CopyFromVec(v); + } + +// Took this out since it is unsafe : Arnab +// /// Constructor from a pointer and a size; copies the data to a location +// /// it owns. +// Vector(const Real* Data, const MatrixIndexT s): VectorBase() { +// Resize(s); + // CopyFromPtr(Data, s); +// } + + + /// Swaps the contents of *this and *other. Shallow swap. + void Swap(Vector *other); + + /// Destructor. Deallocates memory. + ~Vector() { Destroy(); } + + /// Read function using C++ streams. Can also add to existing contents + /// of matrix. + void Read(std::istream &in, bool binary, bool add = false); + + /// Set vector to a specified size (can be zero). + /// The value of the new data depends on resize_type: + /// -if kSetZero, the new data will be zero + /// -if kUndefined, the new data will be undefined + /// -if kCopyData, the new data will be the same as the old data in any + /// shared positions, and zero elsewhere. + /// This function takes time proportional to the number of data elements. + void Resize(MatrixIndexT length, MatrixResizeType resize_type = kSetZero); + + /// Remove one element and shifts later elements down. + void RemoveElement(MatrixIndexT i); + + /// Assignment operator. + Vector &operator = (const Vector &other) { + Resize(other.Dim(), kUndefined); + this->CopyFromVec(other); + return *this; + } + + /// Assignment operator that takes VectorBase. + Vector &operator = (const VectorBase &other) { + Resize(other.Dim(), kUndefined); + this->CopyFromVec(other); + return *this; + } + private: + /// Init assumes the current contents of the class are invalid (i.e. junk or + /// has already been freed), and it sets the vector to newly allocated memory + /// with the specified dimension. dim == 0 is acceptable. 
The memory contents + /// pointed to by data_ will be undefined. + void Init(const MatrixIndexT dim); + + /// Destroy function, called internally. + void Destroy(); + +}; + + +/// Represents a non-allocating general vector which can be defined +/// as a sub-vector of higher-level vector [or as the row of a matrix]. +template +class SubVector : public VectorBase { + public: + /// Constructor from a Vector or SubVector. + /// SubVectors are not const-safe and it's very hard to make them + /// so for now we just give up. This function contains const_cast. + SubVector(const VectorBase &t, const MatrixIndexT origin, + const MatrixIndexT length) : VectorBase() { + // following assert equiv to origin>=0 && length>=0 && + // origin+length <= rt.dim_ + KALDI_ASSERT(static_cast(origin)+ + static_cast(length) <= + static_cast(t.Dim())); + VectorBase::data_ = const_cast (t.Data()+origin); + VectorBase::dim_ = length; + } + + /// This constructor initializes the vector to point at the contents + /// of this packed matrix (SpMatrix or TpMatrix). + SubVector(const PackedMatrix &M) { + VectorBase::data_ = const_cast (M.Data()); + VectorBase::dim_ = (M.NumRows()*(M.NumRows()+1))/2; + } + + /// Copy constructor + SubVector(const SubVector &other) : VectorBase () { + // this copy constructor needed for Range() to work in base class. + VectorBase::data_ = other.data_; + VectorBase::dim_ = other.dim_; + } + + /// Constructor from a pointer to memory and a length. Keeps a pointer + /// to the data but does not take ownership (will never delete). + /// Caution: this constructor enables you to evade const constraints. + SubVector(const Real *data, MatrixIndexT length) : VectorBase () { + VectorBase::data_ = const_cast(data); + VectorBase::dim_ = length; + } + + /// This operation does not preserve const-ness, so be careful. 
+ SubVector(const MatrixBase &matrix, MatrixIndexT row) { + VectorBase::data_ = const_cast(matrix.RowData(row)); + VectorBase::dim_ = matrix.NumCols(); + } + + ~SubVector() {} ///< Destructor (does nothing; no pointers are owned here). + + private: + /// Disallow assignment operator. + SubVector & operator = (const SubVector &other) {} +}; + +/// @} end of "addtogroup matrix_group" +/// \addtogroup matrix_funcs_io +/// @{ +/// Output to a C++ stream. Non-binary by default (use Write for +/// binary output). +template +std::ostream & operator << (std::ostream & out, const VectorBase & v); + +/// Input from a C++ stream. Will automatically read text or +/// binary data from the stream. +template +std::istream & operator >> (std::istream & in, VectorBase & v); + +/// Input from a C++ stream. Will automatically read text or +/// binary data from the stream. +template +std::istream & operator >> (std::istream & in, Vector & v); +/// @} end of \addtogroup matrix_funcs_io + +/// \addtogroup matrix_funcs_scalar +/// @{ + + +template +bool ApproxEqual(const VectorBase &a, + const VectorBase &b, Real tol = 0.01) { + return a.ApproxEqual(b, tol); +} + +template +inline void AssertEqual(VectorBase &a, VectorBase &b, + float tol = 0.01) { + KALDI_ASSERT(a.ApproxEqual(b, tol)); +} + + +/// Returns dot product between v1 and v2. +template +Real VecVec(const VectorBase &v1, const VectorBase &v2); + +template +Real VecVec(const VectorBase &v1, const VectorBase &v2); + + +/// Returns \f$ v_1^T M v_2 \f$ . +/// Not as efficient as it could be where v1 == v2. 
+template +Real VecMatVec(const VectorBase &v1, const MatrixBase &M, + const VectorBase &v2); + +/// @} End of "addtogroup matrix_funcs_scalar" + + +} // namespace kaldi + +// we need to include the implementation +#include "matrix/kaldi-vector-inl.h" + + + +#endif // KALDI_MATRIX_KALDI_VECTOR_H_ diff --git a/speechx/speechx/kaldi/matrix/matrix-common.h b/speechx/speechx/kaldi/matrix/matrix-common.h new file mode 100644 index 00000000..f7047d71 --- /dev/null +++ b/speechx/speechx/kaldi/matrix/matrix-common.h @@ -0,0 +1,111 @@ +// matrix/matrix-common.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_MATRIX_MATRIX_COMMON_H_ +#define KALDI_MATRIX_MATRIX_COMMON_H_ + +// This file contains some #includes, forward declarations +// and typedefs that are needed by all the main header +// files in this directory. 
+ +#include "base/kaldi-common.h" + +namespace kaldi { +// this enums equal to CblasTrans and CblasNoTrans constants from CBLAS library +// we are writing them as literals because we don't want to include here matrix/kaldi-blas.h, +// which puts many symbols into global scope (like "real") via the header f2c.h +typedef enum { + kTrans = 112, // = CblasTrans + kNoTrans = 111 // = CblasNoTrans +} MatrixTransposeType; + +typedef enum { + kSetZero, + kUndefined, + kCopyData +} MatrixResizeType; + + +typedef enum { + kDefaultStride, + kStrideEqualNumCols, +} MatrixStrideType; + +typedef enum { + kTakeLower, + kTakeUpper, + kTakeMean, + kTakeMeanAndCheck +} SpCopyType; + +template class VectorBase; +template class Vector; +template class SubVector; +template class MatrixBase; +template class SubMatrix; +template class Matrix; +template class SpMatrix; +template class TpMatrix; +template class PackedMatrix; +template class SparseMatrix; + +// these are classes that won't be defined in this +// directory; they're mostly needed for friend declarations. +template class CuMatrixBase; +template class CuSubMatrix; +template class CuMatrix; +template class CuVectorBase; +template class CuSubVector; +template class CuVector; +template class CuPackedMatrix; +template class CuSpMatrix; +template class CuTpMatrix; +template class CuSparseMatrix; + +class CompressedMatrix; +class GeneralMatrix; + +/// This class provides a way for switching between double and float types. +template class OtherReal { }; // useful in reading+writing routines + // to switch double and float. +/// A specialized class for switching from float to double. +template<> class OtherReal { + public: + typedef double Real; +}; +/// A specialized class for switching from double to float. 
+template<> class OtherReal { + public: + typedef float Real; +}; + + +typedef int32 MatrixIndexT; +typedef int32 SignedMatrixIndexT; +typedef uint32 UnsignedMatrixIndexT; + +// If you want to use size_t for the index type, do as follows instead: +//typedef size_t MatrixIndexT; +//typedef ssize_t SignedMatrixIndexT; +//typedef size_t UnsignedMatrixIndexT; + +} + + + +#endif // KALDI_MATRIX_MATRIX_COMMON_H_ diff --git a/speechx/speechx/kaldi/matrix/matrix-functions-inl.h b/speechx/speechx/kaldi/matrix/matrix-functions-inl.h new file mode 100644 index 00000000..9fac851e --- /dev/null +++ b/speechx/speechx/kaldi/matrix/matrix-functions-inl.h @@ -0,0 +1,56 @@ +// matrix/matrix-functions-inl.h + +// Copyright 2009-2011 Microsoft Corporation +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +// +// (*) incorporates, with permission, FFT code from his book +// "Signal Processing with Lapped Transforms", Artech, 1992. + + + +#ifndef KALDI_MATRIX_MATRIX_FUNCTIONS_INL_H_ +#define KALDI_MATRIX_MATRIX_FUNCTIONS_INL_H_ + +namespace kaldi { + +//! ComplexMul implements, inline, the complex multiplication b *= a. 
+template inline void ComplexMul(const Real &a_re, const Real &a_im, + Real *b_re, Real *b_im) { + Real tmp_re = (*b_re * a_re) - (*b_im * a_im); + *b_im = *b_re * a_im + *b_im * a_re; + *b_re = tmp_re; +} + +template inline void ComplexAddProduct(const Real &a_re, const Real &a_im, + const Real &b_re, const Real &b_im, + Real *c_re, Real *c_im) { + *c_re += b_re*a_re - b_im*a_im; + *c_im += b_re*a_im + b_im*a_re; +} + + +template inline void ComplexImExp(Real x, Real *a_re, Real *a_im) { + *a_re = std::cos(x); + *a_im = std::sin(x); +} + + +} // end namespace kaldi + + +#endif // KALDI_MATRIX_MATRIX_FUNCTIONS_INL_H_ + diff --git a/speechx/speechx/kaldi/matrix/matrix-functions.cc b/speechx/speechx/kaldi/matrix/matrix-functions.cc new file mode 100644 index 00000000..496c09f5 --- /dev/null +++ b/speechx/speechx/kaldi/matrix/matrix-functions.cc @@ -0,0 +1,773 @@ +// matrix/matrix-functions.cc + +// Copyright 2009-2011 Microsoft Corporation; Go Vivace Inc.; Jan Silovsky +// Yanmin Qian; Saarland University; Johns Hopkins University (Author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +// +// (*) incorporates, with permission, FFT code from his book +// "Signal Processing with Lapped Transforms", Artech, 1992. 
+ +#include "matrix/matrix-functions.h" +#include "matrix/sp-matrix.h" + +namespace kaldi { + +template void ComplexFt (const VectorBase &in, + VectorBase *out, bool forward) { + int exp_sign = (forward ? -1 : 1); + KALDI_ASSERT(out != NULL); + KALDI_ASSERT(in.Dim() == out->Dim()); + KALDI_ASSERT(in.Dim() % 2 == 0); + int twoN = in.Dim(), N = twoN / 2; + const Real *data_in = in.Data(); + Real *data_out = out->Data(); + + Real exp1N_re, exp1N_im; // forward -> exp(-2pi / N), backward -> exp(2pi / N). + Real fraction = exp_sign * M_2PI / static_cast(N); // forward -> -2pi/N, backward->-2pi/N + ComplexImExp(fraction, &exp1N_re, &exp1N_im); + + Real expm_re = 1.0, expm_im = 0.0; // forward -> exp(-2pi m / N). + + for (int two_m = 0; two_m < twoN; two_m+=2) { // For each output component. + Real expmn_re = 1.0, expmn_im = 0.0; // forward -> exp(-2pi m n / N). + Real sum_re = 0.0, sum_im = 0.0; // complex output for index m (the sum expression) + for (int two_n = 0; two_n < twoN; two_n+=2) { + ComplexAddProduct(data_in[two_n], data_in[two_n+1], + expmn_re, expmn_im, + &sum_re, &sum_im); + ComplexMul(expm_re, expm_im, &expmn_re, &expmn_im); + } + data_out[two_m] = sum_re; + data_out[two_m + 1] = sum_im; + + + if (two_m % 10 == 0) { // occasionally renew "expm" from scratch to avoid + // loss of precision. + int nextm = 1 + two_m/2; + Real fraction_mult = fraction * nextm; + ComplexImExp(fraction_mult, &expm_re, &expm_im); + } else { + ComplexMul(exp1N_re, exp1N_im, &expm_re, &expm_im); + } + } +} + +template +void ComplexFt (const VectorBase &in, + VectorBase *out, bool forward); +template +void ComplexFt (const VectorBase &in, + VectorBase *out, bool forward); + + +#define KALDI_COMPLEXFFT_BLOCKSIZE 8192 +// This #define affects how we recurse in ComplexFftRecursive. +// We assume that memory-caching happens on a scale at +// least as small as this. + + +//! ComplexFftRecursive is a recursive function that computes the +//! complex FFT of size N. 
The "nffts" arguments specifies how many +//! separate FFTs to compute in parallel (we assume the data for +//! each one is consecutive in memory). The "forward argument" +//! specifies whether to do the FFT (true) or IFFT (false), although +//! note that we do not include the factor of 1/N (the user should +//! do this if required. The iterators factor_begin and factor_end +//! point to the beginning and end (i.e. one past the last element) +//! of an array of small factors of N (typically prime factors). +//! See the comments below this code for the detailed equations +//! of the recursion. + + +template +void ComplexFftRecursive (Real *data, int nffts, int N, + const int *factor_begin, + const int *factor_end, bool forward, + Vector *tmp_vec) { + if (factor_begin == factor_end) { + KALDI_ASSERT(N == 1); + return; + } + + { // an optimization: compute in smaller blocks. + // this block of code could be removed and it would still work. + MatrixIndexT size_perblock = N * 2 * sizeof(Real); + if (nffts > 1 && size_perblock*nffts > KALDI_COMPLEXFFT_BLOCKSIZE) { // can break it up... + // Break up into multiple blocks. This is an optimization. We make + // no progress on the FFT when we do this. + int block_skip = KALDI_COMPLEXFFT_BLOCKSIZE / size_perblock; // n blocks per call + if (block_skip == 0) block_skip = 1; + if (block_skip < nffts) { + int blocks_left = nffts; + while (blocks_left > 0) { + int skip_now = std::min(blocks_left, block_skip); + ComplexFftRecursive(data, skip_now, N, factor_begin, factor_end, forward, tmp_vec); + blocks_left -= skip_now; + data += skip_now * N*2; + } + return; + } // else do the actual algorithm. + } // else do the actual algorithm. + } + + int P = *factor_begin; + KALDI_ASSERT(P > 1); + int Q = N / P; + + + if (P > 1 && Q > 1) { // Do the rearrangement. C.f. eq. (8) below. Transform + // (a) to (b). 
+ Real *data_thisblock = data; + if (tmp_vec->Dim() < (MatrixIndexT)N) tmp_vec->Resize(N); + Real *data_tmp = tmp_vec->Data(); + for (int thisfft = 0; thisfft < nffts; thisfft++, data_thisblock+=N*2) { + for (int offset = 0; offset < 2; offset++) { // 0 == real, 1 == im. + for (int p = 0; p < P; p++) { + for (int q = 0; q < Q; q++) { + int aidx = q*P + p, bidx = p*Q + q; + data_tmp[bidx] = data_thisblock[2*aidx+offset]; + } + } + for (int n = 0;n < P*Q;n++) data_thisblock[2*n+offset] = data_tmp[n]; + } + } + } + + { // Recurse. + ComplexFftRecursive(data, nffts*P, Q, factor_begin+1, factor_end, forward, tmp_vec); + } + + int exp_sign = (forward ? -1 : 1); + Real rootN_re, rootN_im; // Nth root of unity. + ComplexImExp(static_cast(exp_sign * M_2PI / N), &rootN_re, &rootN_im); + + Real rootP_re, rootP_im; // Pth root of unity. + ComplexImExp(static_cast(exp_sign * M_2PI / P), &rootP_re, &rootP_im); + + { // Do the multiplication + // could avoid a bunch of complex multiplies by moving the loop over data_thisblock + // inside. + if (tmp_vec->Dim() < (MatrixIndexT)(P*2)) tmp_vec->Resize(P*2); + Real *temp_a = tmp_vec->Data(); + + Real *data_thisblock = data, *data_end = data+(N*2*nffts); + for (; data_thisblock != data_end; data_thisblock += N*2) { // for each separate fft. + Real qd_re = 1.0, qd_im = 0.0; // 1^(q'/N) + for (int qd = 0; qd < Q; qd++) { + Real pdQ_qd_re = qd_re, pdQ_qd_im = qd_im; // 1^((p'Q+q') / N) == 1^((p'/P) + (q'/N)) + // Initialize to q'/N, corresponding to p' == 0. + for (int pd = 0; pd < P; pd++) { // pd == p' + { // This is the p = 0 case of the loop below [an optimization]. + temp_a[pd*2] = data_thisblock[qd*2]; + temp_a[pd*2 + 1] = data_thisblock[qd*2 + 1]; + } + { // This is the p = 1 case of the loop below [an optimization] + // **** MOST OF THE TIME (>60% I think) gets spent here. 
*** + ComplexAddProduct(pdQ_qd_re, pdQ_qd_im, + data_thisblock[(qd+Q)*2], data_thisblock[(qd+Q)*2 + 1], + &(temp_a[pd*2]), &(temp_a[pd*2 + 1])); + } + if (P > 2) { + Real p_pdQ_qd_re = pdQ_qd_re, p_pdQ_qd_im = pdQ_qd_im; // 1^(p(p'Q+q')/N) + for (int p = 2; p < P; p++) { + ComplexMul(pdQ_qd_re, pdQ_qd_im, &p_pdQ_qd_re, &p_pdQ_qd_im); // p_pdQ_qd *= pdQ_qd. + int data_idx = p*Q + qd; + ComplexAddProduct(p_pdQ_qd_re, p_pdQ_qd_im, + data_thisblock[data_idx*2], data_thisblock[data_idx*2 + 1], + &(temp_a[pd*2]), &(temp_a[pd*2 + 1])); + } + } + if (pd != P-1) + ComplexMul(rootP_re, rootP_im, &pdQ_qd_re, &pdQ_qd_im); // pdQ_qd *= (rootP == 1^{1/P}) + // (using 1/P == Q/N) + } + for (int pd = 0; pd < P; pd++) { + data_thisblock[(pd*Q + qd)*2] = temp_a[pd*2]; + data_thisblock[(pd*Q + qd)*2 + 1] = temp_a[pd*2 + 1]; + } + ComplexMul(rootN_re, rootN_im, &qd_re, &qd_im); // qd *= rootN. + } + } + } +} + +/* Equations for ComplexFftRecursive. + We consider here one of the "nffts" separate ffts; it's just a question of + doing them all in parallel. We also write all equations in terms of + complex math (the conversion to real arithmetic is not hard, and anyway + takes place inside function calls). + + + Let the input (i.e. "data" at start) be a_n, n = 0..N-1, and + the output (Fourier transform) be d_k, k = 0..N-1. We use these letters because + there will be two intermediate variables b and c. + We want to compute: + + d_k = \sum_n a_n 1^(kn/N) (1) + + where we use 1^x as shorthand for exp(-2pi x) for the forward algorithm + and exp(2pi x) for the backward one. + + We factorize N = P Q (P small, Q usually large). + With p = 0..P-1 and q = 0..Q-1, and also p'=0..P-1 and q'=0..P-1, we let: + + k == p'Q + q' (2) + n == qP + p (3) + + That is, we let p, q, p', q' range over these indices and observe that this way we + can cover all n, k. 
Expanding (1) using (2) and (3), we can write: + + d_k = \sum_{p, q} a_n 1^((p'Q+q')(qP+p)/N) + = \sum_{p, q} a_n 1^(p'pQ/N) 1^(q'qP/N) 1^(q'p/N) (4) + + using 1^(PQ/N) = 1 to get rid of the terms with PQ in them. Rearranging (4), + + d_k = \sum_p 1^(p'pQ/N) 1^(q'p/N) \sum_q 1^(q'qP/N) a_n (5) + + The point here is to separate the index q. Now we can expand out the remaining + instances of k and n using (2) and (3): + + d_(p'Q+q') = \sum_p 1^(p'pQ/N) 1^(q'p/N) \sum_q 1^(q'qP/N) a_(qP+p) (6) + + The expression \sum_q varies with the indices p and q'. Let us define + + C_{p, q'} = \sum_q 1^(q'qP/N) a_(qP+p) (7) + + Here, C_{p, q'}, viewed as a sequence in q', is just the DFT of the points + a_(qP+p) for q = 1..Q-1. These points are not consecutive in memory though, + they jump by P each time. Let us define b as a rearranged version of a, + so that + + b_(pQ+q) = a_(qP+p) (8) + + How to do this rearrangement in place? In + + We can rearrange (7) to be written in terms of the b's, using (8), so that + + C_{p, q'} = \sum_q 1^(q'q (P/N)) b_(pQ+q) (9) + + Here, the sequence of C_{p, q'} over q'=0..Q-1, is just the DFT of the sequence + of b_(pQ) .. b_(p(Q+1)-1). Let's arrange the C_{p, q'} in a single array in + memory in the same way as the b's, i.e. we define + c_(pQ+q') == C_{p, q'}. (10) + Note that we could have written (10) with q in place of q', as there is only + one index of type q present, but q' is just a more natural variable name to use + since we use q' elsewhere to subscript c and C. + + Rewriting (9), we have: + c_(pQ+q') = \sum_q 1^(q'q (P/N)) b_(pQ+q) (11) + which is the DFT computed by the recursive call to this function [after computing + the b's by rearranging the a's]. From the c's we want to compute the d's. + Taking (6), substituting in the sum (7), and using (10) to write it as an array, + we have: + d_(p'Q+q') = \sum_p 1^(p'pQ/N) 1^(q'p/N) c_(pQ+q') (12) + This sum is independent for different values of q'. Note that d overwrites c + in memory. 
We compute this in a direct way, using a little array of size P to + store the computed d values for one value of q' (we reuse the array for each value + of q'). + + So the overall picture is this: + We get a call to compute DFT on size N. + + - If N == 1 we return (nothing to do). + - We factor N = P Q (typically, P is small). + - Using (8), we rearrange the data in memory so that we have b not a in memory + (this is the block "do the rearrangement"). + The pseudocode for this is as follows. For simplicity we use a temporary array. + + for p = 0..P-1 + for q = 0..Q-1 + bidx = pQ + q + aidx = qP + p + tmp[bidx] = data[aidx]. + end + end + data <-- tmp + else + + endif + + + The reason this accomplishes (8) is that we want pQ+q and qP+p to be swapped + over for each p, q, and the "if m > n" is a convenient way of ensuring that + this swapping happens only once (otherwise it would happen twice, since pQ+q + and qP+p both range over the entire set of numbers 0..N-1). + + - We do the DFT on the smaller block size to compute c from b (this is eq. (11)). + Note that this is actually multiple DFTs, one for each value of p, but this + goes to the "nffts" argument of the function call, which we have ignored up to now. + + - We compute eq. (12) via a loop, as follows + allocate temporary array e of size P. + For q' = 0..Q-1: + for p' = 0..P-1: + set sum to zero [this will go in e[p']] + for p = 0..P-1: + sum += 1^(p'pQ/N) 1^(q'p/N) c_(pQ+q') + end + e[p'] = sum + end + for p' = 0..P-1: + d_(p'Q+q') = e[p'] + end + end + delete temporary array e + +*/ + +// This is the outer-layer calling code for ComplexFftRecursive. +// It factorizes the dimension and then calls the FFT routine. +template void ComplexFft(VectorBase *v, bool forward, Vector *tmp_in) { + KALDI_ASSERT(v != NULL); + + if (v->Dim()<=1) return; + KALDI_ASSERT(v->Dim() % 2 == 0); // complex input.
+ int N = v->Dim() / 2; + std::vector factors; + Factorize(N, &factors); + int *factor_beg = NULL; + if (factors.size() > 0) + factor_beg = &(factors[0]); + Vector tmp; // allocated in ComplexFftRecursive. + ComplexFftRecursive(v->Data(), 1, N, factor_beg, factor_beg+factors.size(), forward, (tmp_in?tmp_in:&tmp)); +} + +//! Inefficient version of Fourier transform, for testing purposes. +template void RealFftInefficient (VectorBase *v, bool forward) { + KALDI_ASSERT(v != NULL); + MatrixIndexT N = v->Dim(); + KALDI_ASSERT(N%2 == 0); + if (N == 0) return; + Vector vtmp(N*2); // store as complex. + if (forward) { + for (MatrixIndexT i = 0; i < N; i++) vtmp(i*2) = (*v)(i); + ComplexFft(&vtmp, forward); // this is already tested so we can use this. + v->CopyFromVec( vtmp.Range(0, N) ); + (*v)(1) = vtmp(N); // Copy the N/2'th fourier component, which is real, + // to the imaginary part of the 1st complex output. + } else { + // reverse the transformation above to get the complex spectrum. + vtmp(0) = (*v)(0); // copy F_0 which is real + vtmp(N) = (*v)(1); // copy F_{N/2} which is real + for (MatrixIndexT i = 1; i < N/2; i++) { + // Copy i'th to i'th fourier component + vtmp(2*i) = (*v)(2*i); + vtmp(2*i+1) = (*v)(2*i+1); + // Copy i'th to N-i'th, conjugated. + vtmp(2*(N-i)) = (*v)(2*i); + vtmp(2*(N-i)+1) = -(*v)(2*i+1); + } + ComplexFft(&vtmp, forward); // actually backward since forward == false + // Copy back real part. Complex part should be zero. + for (MatrixIndexT i = 0; i < N; i++) + (*v)(i) = vtmp(i*2); + } +} + +template void RealFftInefficient (VectorBase *v, bool forward); +template void RealFftInefficient (VectorBase *v, bool forward); + +template +void ComplexFft(VectorBase *v, bool forward, Vector *tmp_in); +template +void ComplexFft(VectorBase *v, bool forward, Vector *tmp_in); + + +// See the long comment below for the math behind this. 
+template void RealFft (VectorBase *v, bool forward) { + KALDI_ASSERT(v != NULL); + MatrixIndexT N = v->Dim(), N2 = N/2; + KALDI_ASSERT(N%2 == 0); + if (N == 0) return; + + if (forward) ComplexFft(v, true); + + Real *data = v->Data(); + Real rootN_re, rootN_im; // exp(-2pi/N), forward; exp(2pi/N), backward + int forward_sign = forward ? -1 : 1; + ComplexImExp(static_cast(M_2PI/N *forward_sign), &rootN_re, &rootN_im); + Real kN_re = -forward_sign, kN_im = 0.0; // exp(-2pik/N), forward; exp(-2pik/N), backward + // kN starts out as 1.0 for forward algorithm but -1.0 for backward. + for (MatrixIndexT k = 1; 2*k <= N2; k++) { + ComplexMul(rootN_re, rootN_im, &kN_re, &kN_im); + + Real Ck_re, Ck_im, Dk_re, Dk_im; + // C_k = 1/2 (B_k + B_{N/2 - k}^*) : + Ck_re = 0.5 * (data[2*k] + data[N - 2*k]); + Ck_im = 0.5 * (data[2*k + 1] - data[N - 2*k + 1]); + // re(D_k)= 1/2 (im(B_k) + im(B_{N/2-k})): + Dk_re = 0.5 * (data[2*k + 1] + data[N - 2*k + 1]); + // im(D_k) = -1/2 (re(B_k) - re(B_{N/2-k})) + Dk_im =-0.5 * (data[2*k] - data[N - 2*k]); + // A_k = C_k + 1^(k/N) D_k: + data[2*k] = Ck_re; // A_k <-- C_k + data[2*k+1] = Ck_im; + // now A_k += D_k 1^(k/N) + ComplexAddProduct(Dk_re, Dk_im, kN_re, kN_im, &(data[2*k]), &(data[2*k+1])); + + MatrixIndexT kdash = N2 - k; + if (kdash != k) { + // Next we handle the index k' = N/2 - k. This is necessary + // to do now, to avoid invalidating data that we will later need. + // The quantities C_{k'} and D_{k'} are just the conjugates of C_k + // and D_k, so the equations are simple modifications of the above, + // replacing Ck_im and Dk_im with their negatives. + data[2*kdash] = Ck_re; // A_k' <-- C_k' + data[2*kdash+1] = -Ck_im; + // now A_k' += D_k' 1^(k'/N) + // We use 1^(k'/N) = 1^((N/2 - k) / N) = 1^(1/2) 1^(-k/N) = -1 * (1^(k/N))^* + // so it's the same as 1^(k/N) but with the real part negated. + ComplexAddProduct(Dk_re, -Dk_im, -kN_re, kN_im, &(data[2*kdash]), &(data[2*kdash+1])); + } + } + + { // Now handle k = 0. 
+ // In simple terms: after the complex fft, data[0] becomes the sum of real + // parts input[0], input[2]... and data[1] becomes the sum of imaginary + // pats input[1], input[3]... + // "zeroth" [A_0] is just the sum of input[0]+input[1]+input[2].. + // and "n2th" [A_{N/2}] is input[0]-input[1]+input[2]... . + Real zeroth = data[0] + data[1], + n2th = data[0] - data[1]; + data[0] = zeroth; + data[1] = n2th; + if (!forward) { + data[0] /= 2; + data[1] /= 2; + } + } + + if (!forward) { + ComplexFft(v, false); + v->Scale(2.0); // This is so we get a factor of N increase, rather than N/2 which we would + // otherwise get from [ComplexFft, forward] + [ComplexFft, backward] in dimension N/2. + // It's for consistency with our normal FFT convensions. + } +} + +template void RealFft (VectorBase *v, bool forward); +template void RealFft (VectorBase *v, bool forward); + +/* Notes for real FFTs. + We are using the same convention as above, 1^x to mean exp(-2\pi x) for the forward transform. + Actually, in a slight abuse of notation, we use this meaning for 1^x in both the forward and + backward cases because it's more convenient in this section. + + Suppose we have real data a[0...N-1], with N even, and want to compute its Fourier transform. + We can make do with the first N/2 points of the transform, since the remaining ones are complex + conjugates of the first. We want to compute: + for k = 0...N/2-1, + A_k = \sum_{n = 0}^{N-1} a_n 1^(kn/N) (1) + + We treat a[0..N-1] as a complex sequence of length N/2, i.e. a sequence b[0..N/2 - 1]. + Viewed as sequences of length N/2, we have: + b = c + i d, + where c = a_0, a_2 ... and d = a_1, a_3 ... + + We can recover the length-N/2 Fourier transforms of c and d by doing FT on b and + then doing the equations below. Derivation is marked by (*) in a comment below (search + for it). Let B, C, D be the FTs. 
+ We have + C_k = 1/2 (B_k + B_{N/2 - k}^*) (z0) + D_k =-1/2i (B_k - B_{N/2 - k}^*) (z1) +so: re(D_k)= 1/2 (im(B_k) + im(B_{N/2-k})) (z2) + im(D_k) = -1/2 (re(B_k) - re(B_{N/2-k})) (z3) + + To recover the FT A from C and D, we write, rearranging (1): + + A_k = \sum_{n = 0, 2, ..., N-2} a_n 1^(kn/N) + +\sum_{n = 1, 3, ..., N-1} a_n 1^(kn/N) + = \sum_{n = 0, 1, ..., N/2-1} a_n 1^(2kn/N) + a_{n+1} 1^(2kn/N) 1^(k/N) + = \sum_{n = 0, 1, ..., N/2-1} c_n 1^(2kn/N) + d_n 1^(2kn/N) 1^(k/N) + A_k = C_k + 1^(k/N) D_k (a0) + + This equation is valid for k = 0...N/2-1, which is the range of the sequences B_k and + C_k. We don't use it for k = 0, which is a special case considered below. For + 1 < k < N/2, it's convenient to consider the pair k, k', where k' = N/2 - k. + Remember that C_k' = C_k^* and D_k' = D_k^* [where * is conjugation]. Also, + 1^(N/2 / N) = -1. So we have: + A_k' = C_k^* - 1^(k/N) D_k^* (a0b) + We do (a0) and (a0b) together. + + + + By symmetry this gives us the Fourier components for N/2+1, ... N, if we want + them. However, it doesn't give us the value for exactly k = N/2. For k = 0 and k = N/2, it + is easiest to argue directly about the meaning of the A_k, B_k and C_k in terms of + sums of points. + A_0 and A_{N/2} are both real, with A_0=\sum_n a_n, and A_{N/2} an alternating sum + A_{N/2} = a_0 - a_1 + a_2 ... + It's easy to show that + A_0 = B_0 + C_0 (a1) + A_{N/2} = B_0 - C_0. (a2) + Since B_0 and C_0 are both real, B_0 is the real coefficient of D_0 and C_0 is the + imaginary coefficient. + + *REVERSING THE PROCESS* + + Next we want to reverse this process. We just need to work out C_k and D_k from the + sequence A_k. Then we do the inverse complex fft and we get back where we started. + For 0 and N/2, working from (a1) and (a2) above, we can see that: + B_0 = 1/2 (A_0 + A_{N/2}) (y0) + C_0 = 1/2 (A_0 - A_{N/2}) (y1) + and we use + D_0 = B_0 + i C_0 + to get the 1st complex coefficient of D.
This is exactly the same as the forward process + except with an extra factor of 1/2. + + Consider equations (a0) and (a0b). We want to work out C_k and D_k from A_k and A_k'. Remember + k' = N/2 - k. + + Write down + A_k = C_k + 1^(k/N) D_k (copying a0) + A_k'^* = C_k - 1^(k/N) D_k (conjugate of a0b) + So + C_k = 0.5 (A_k + A_k'^*) (p0) + D_k = 1^(-k/N) . 0.5 (A_k - A_k'^*) (p1) + Next, we want to compute B_k and B_k' from C_k and D_k. C.f. (z0)..(z3), and remember + that k' = N/2-k. We can see + that + B_k = C_k + i D_k (p2) + B_k' = C_k - i D_k (p3) + + We would like to make the equations (p0) ... (p3) look like the forward equations (z0), (z1), + (a0) and (a0b) so we can reuse the code. Define E_k = -i 1^(k/N) D_k. Then write down (p0)..(p3). + We have + C_k = 0.5 (A_k + A_k'^*) (p0') + E_k = -0.5 i (A_k - A_k'^*) (p1') + B_k = C_k - 1^(-k/N) E_k (p2') + B_k' = C_k + 1^(-k/N) E_k (p3') + So these are exactly the same as (z0), (z1), (a0), (a0b) except replacing 1^(k/N) with + -1^(-k/N) . Remember that we defined 1^x above to be exp(-2pi x/N), so the signs here + might be opposite to what you see in the code. + + MODIFICATION: we need to take care of a factor of two. The complex FFT we implemented + does not divide by N in the reverse case. So upon inversion we get larger by N/2. + However, this is not consistent with normal FFT conventions where you get a factor of N. + For this reason we multiply by two after the process described above. + +*/ + + +/* + (*) [this token is referred to in a comment above]. + + Notes for separating 2 real transforms from one complex one. Note that the + letters here (A, B, C and N) are all distinct from the same letters used in the + place where this comment is used. + Suppose we + have two sequences a_n and b_n, n = 0..N-1. We combine them into a complex + number, + c_n = a_n + i b_n. + Then we take the fourier transform to get + C_k = \sum_{n = 0}^{N-1} c_n 1^(n/N) . + Then we use symmetry. 
Define A_k and B_k as the DFTs of a and b. + We use A_k = A_{N-k}^*, and B_k = B_{N-k}^*, since a and b are real. Using + C_k = A_k + i B_k, + C_{N-k} = A_k^* + i B_k^* + = A_k^* - (i B_k)^* + So: + A_k = 1/2 (C_k + C_{N-k}^*) + i B_k = 1/2 (C_k - C_{N-k}^*) +-> B_k =-1/2i (C_k - C_{N-k}^*) +-> re(B_k) = 1/2 (im(C_k) + im(C_{N-k})) + im(B_k) =-1/2 (re(C_k) - re(C_{N-k})) + + */ + +template void ComputeDctMatrix(Matrix *M) { + //KALDI_ASSERT(M->NumRows() == M->NumCols()); + MatrixIndexT K = M->NumRows(); + MatrixIndexT N = M->NumCols(); + + KALDI_ASSERT(K > 0); + KALDI_ASSERT(N > 0); + Real normalizer = std::sqrt(1.0 / static_cast(N)); // normalizer for + // X_0. + for (MatrixIndexT j = 0; j < N; j++) (*M)(0, j) = normalizer; + normalizer = std::sqrt(2.0 / static_cast(N)); // normalizer for other + // elements. + for (MatrixIndexT k = 1; k < K; k++) + for (MatrixIndexT n = 0; n < N; n++) + (*M)(k, n) = normalizer + * std::cos( static_cast(M_PI)/N * (n + 0.5) * k ); +} + + +template void ComputeDctMatrix(Matrix *M); +template void ComputeDctMatrix(Matrix *M); + + +template +void ComputePca(const MatrixBase &X, + MatrixBase *U, + MatrixBase *A, + bool print_eigs, + bool exact) { + // Note that some of these matrices may be transposed w.r.t. the + // way it's most natural to describe them in math... it's the rows + // of X and U that correspond to the (data-points, basis elements). + MatrixIndexT N = X.NumRows(), D = X.NumCols(); + // N = #points, D = feature dim. + KALDI_ASSERT(U != NULL && U->NumCols() == D); + MatrixIndexT G = U->NumRows(); // # of retained basis elements. + KALDI_ASSERT(A == NULL || (A->NumRows() == N && A->NumCols() == G)); + KALDI_ASSERT(G <= N && G <= D); + if (D < N) { // Do conventional PCA. + SpMatrix Msp(D); // Matrix of outer products. 
+ Msp.AddMat2(1.0, X, kTrans, 0.0); // M <-- X^T X + Matrix Utmp; + Vector l; + if (exact) { + Utmp.Resize(D, D); + l.Resize(D); + //Matrix M(Msp); + //M.DestructiveSvd(&l, &Utmp, NULL); + Msp.Eig(&l, &Utmp); + } else { + Utmp.Resize(D, G); + l.Resize(G); + Msp.TopEigs(&l, &Utmp); + } + SortSvd(&l, &Utmp); + + for (MatrixIndexT g = 0; g < G; g++) + U->Row(g).CopyColFromMat(Utmp, g); + if (print_eigs) + KALDI_LOG << (exact ? "" : "Retained ") + << "PCA eigenvalues are " << l; + if (A != NULL) + A->AddMatMat(1.0, X, kNoTrans, *U, kTrans, 0.0); + } else { // Do inner-product PCA. + SpMatrix Nsp(N); // Matrix of inner products. + Nsp.AddMat2(1.0, X, kNoTrans, 0.0); // M <-- X X^T + + Matrix Vtmp; + Vector l; + if (exact) { + Vtmp.Resize(N, N); + l.Resize(N); + Matrix Nmat(Nsp); + Nmat.DestructiveSvd(&l, &Vtmp, NULL); + } else { + Vtmp.Resize(N, G); + l.Resize(G); + Nsp.TopEigs(&l, &Vtmp); + } + + MatrixIndexT num_zeroed = 0; + for (MatrixIndexT g = 0; g < G; g++) { + if (l(g) < 0.0) { + KALDI_WARN << "In PCA, setting element " << l(g) << " to zero."; + l(g) = 0.0; + num_zeroed++; + } + } + SortSvd(&l, &Vtmp); // Make sure zero elements are last, this + // is necessary for Orthogonalize() to work properly later. + + Vtmp.Transpose(); // So eigenvalues are the rows. + + for (MatrixIndexT g = 0; g < G; g++) { + Real sqrtlg = sqrt(l(g)); + if (l(g) != 0.0) { + U->Row(g).AddMatVec(1.0 / sqrtlg, X, kTrans, Vtmp.Row(g), 0.0); + } else { + U->Row(g).SetZero(); + (*U)(g, g) = 1.0; // arbitrary direction. Will later orthogonalize. + } + if (A != NULL) + for (MatrixIndexT n = 0; n < N; n++) + (*A)(n, g) = sqrtlg * Vtmp(g, n); + } + // Now orthogonalize. This is mainly useful in + // case there were zero eigenvalues, but we do it + // for all of them. 
+ U->OrthogonalizeRows(); + if (print_eigs) + KALDI_LOG << "(inner-product) PCA eigenvalues are " << l; + } +} + + +template +void ComputePca(const MatrixBase &X, + MatrixBase *U, + MatrixBase *A, + bool print_eigs, + bool exact); + +template +void ComputePca(const MatrixBase &X, + MatrixBase *U, + MatrixBase *A, + bool print_eigs, + bool exact); + + +// Added by Dan, Feb. 13 2012. +// This function does: *plus += max(0, a b^T), +// *minus += max(0, -(a b^T)). +template +void AddOuterProductPlusMinus(Real alpha, + const VectorBase &a, + const VectorBase &b, + MatrixBase *plus, + MatrixBase *minus) { + KALDI_ASSERT(a.Dim() == plus->NumRows() && b.Dim() == plus->NumCols() + && a.Dim() == minus->NumRows() && b.Dim() == minus->NumCols()); + int32 nrows = a.Dim(), ncols = b.Dim(), pskip = plus->Stride() - ncols, + mskip = minus->Stride() - ncols; + const Real *adata = a.Data(), *bdata = b.Data(); + Real *plusdata = plus->Data(), *minusdata = minus->Data(); + + for (int32 i = 0; i < nrows; i++) { + const Real *btmp = bdata; + Real multiple = alpha * *adata; + if (multiple > 0.0) { + for (int32 j = 0; j < ncols; j++, plusdata++, minusdata++, btmp++) { + if (*btmp > 0.0) *plusdata += multiple * *btmp; + else *minusdata -= multiple * *btmp; + } + } else { + for (int32 j = 0; j < ncols; j++, plusdata++, minusdata++, btmp++) { + if (*btmp < 0.0) *plusdata += multiple * *btmp; + else *minusdata -= multiple * *btmp; + } + } + plusdata += pskip; + minusdata += mskip; + adata++; + } +} + +// Instantiate template +template +void AddOuterProductPlusMinus(float alpha, + const VectorBase &a, + const VectorBase &b, + MatrixBase *plus, + MatrixBase *minus); +template +void AddOuterProductPlusMinus(double alpha, + const VectorBase &a, + const VectorBase &b, + MatrixBase *plus, + MatrixBase *minus); + + +} // end namespace kaldi diff --git a/speechx/speechx/kaldi/matrix/matrix-functions.h b/speechx/speechx/kaldi/matrix/matrix-functions.h new file mode 100644 index 00000000..ca50ddda --- 
/dev/null +++ b/speechx/speechx/kaldi/matrix/matrix-functions.h @@ -0,0 +1,174 @@ +// matrix/matrix-functions.h + +// Copyright 2009-2011 Microsoft Corporation; Go Vivace Inc.; Jan Silovsky; +// Yanmin Qian; 1991 Henrique (Rico) Malvar (*) +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +// +// (*) incorporates, with permission, FFT code from his book +// "Signal Processing with Lapped Transforms", Artech, 1992. + + + +#ifndef KALDI_MATRIX_MATRIX_FUNCTIONS_H_ +#define KALDI_MATRIX_MATRIX_FUNCTIONS_H_ + +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + +/// @addtogroup matrix_funcs_misc +/// @{ + +/** The function ComplexFft does an Fft on the vector argument v. + v is a vector of even dimension, interpreted for both input + and output as a vector of complex numbers i.e. + \f[ v = ( re_0, im_0, re_1, im_1, ... ) \f] + + If "forward == true" this routine does the Discrete Fourier Transform + (DFT), i.e.: + \f[ vout[m] \leftarrow \sum_{n = 0}^{N-1} vin[i] exp( -2pi m n / N ) \f] + + If "backward" it does the Inverse Discrete Fourier Transform (IDFT) + *WITHOUT THE FACTOR 1/N*, + i.e.: + \f[ vout[m] <-- \sum_{n = 0}^{N-1} vin[i] exp( 2pi m n / N ) \f] + [note the sign difference on the 2 pi for the backward one.] 
+ + Note that this is the definition of the FT given in most texts, but + it differs from the Numerical Recipes version in which the forward + and backward algorithms are flipped. + + Note that you would have to multiply by 1/N after the IDFT to get + back to where you started from. We don't do this because + in some contexts, the transform is made symmetric by multiplying + by sqrt(N) in both passes. The user can do this by themselves. + + See also SplitRadixComplexFft, declared in srfft.h, which is more efficient + but only works if the length of the input is a power of 2. + */ +template void ComplexFft (VectorBase *v, bool forward, Vector *tmp_work = NULL); + +/// ComplexFt is the same as ComplexFft but it implements the Fourier +/// transform in an inefficient way. It is mainly included for testing purposes. +/// See comment for ComplexFft to describe the input and outputs and what it does. +template void ComplexFt (const VectorBase &in, + VectorBase *out, bool forward); + +/// RealFft is a fourier transform of real inputs. Internally it uses +/// ComplexFft. The input dimension N must be even. If forward == true, +/// it transforms from a sequence of N real points to its complex fourier +/// transform; otherwise it goes in the reverse direction. If you call it +/// in the forward and then reverse direction and multiply by 1.0/N, you +/// will get back the original data. +/// The interpretation of the complex-FFT data is as follows: the array +/// is a sequence of complex numbers C_n of length N/2 with (real, im) format, +/// i.e. [real0, real_{N/2}, real1, im1, real2, im2, real3, im3, ...]. +/// See also SplitRadixRealFft, declared in srfft.h, which is more efficient +/// but only works if the length of the input is a power of 2. + +template void RealFft (VectorBase *v, bool forward); + + +/// RealFt has the same input and output format as RealFft above, but it is +/// an inefficient implementation included for testing purposes. 
+template void RealFftInefficient (VectorBase *v, bool forward); + +/// ComputeDctMatrix computes a matrix corresponding to the DCT, such that +/// M * v equals the DCT of vector v. M must be square at input. +/// This is the type = III DCT with normalization, corresponding to the +/// following equations, where x is the signal and X is the DCT: +/// X_0 = 1/sqrt(2*N) \sum_{n = 0}^{N-1} x_n +/// X_k = 1/sqrt(N) \sum_{n = 0}^{N-1} x_n cos( \pi/N (n + 1/2) k ) +/// This matrix's transpose is its own inverse, so transposing this +/// matrix will give the inverse DCT. +/// Caution: the type III DCT is generally known as the "inverse DCT" (with the +/// type II being the actual DCT), so this function is somewhat mis-named. It +/// was probably done this way for HTK compatibility. We don't change it +/// because it was this way from the start and changing it would affect the +/// feature generation. + +template void ComputeDctMatrix(Matrix *M); + + +/// ComplexMul implements, inline, the complex multiplication b *= a. +template inline void ComplexMul(const Real &a_re, const Real &a_im, + Real *b_re, Real *b_im); + +/// ComplexAddProduct implements, inline, the complex operation c += (a * b). +template inline void ComplexAddProduct(const Real &a_re, const Real &a_im, + const Real &b_re, const Real &b_im, + Real *c_re, Real *c_im); + + +/// ComplexImExp implements a <-- exp(i x), inline. +template inline void ComplexImExp(Real x, Real *a_re, Real *a_im); + + + +/** + ComputePCA does a PCA computation, using either outer products + or inner products, whichever is more efficient. Let D be + the dimension of the data points, N be the number of data + points, and G be the PCA dimension we want to retain. We assume + G <= N and G <= D. + + @param X [in] An N x D matrix. Each row of X is a point x_i. + @param U [out] A G x D matrix. Each row of U is a basis element u_i. + @param A [out] An N x G matrix, or NULL.
Each row of A is a set of coefficients + in the basis for a point x_i, so A(i, g) is the coefficient of u_i + in x_i. + @param print_eigs [in] If true, prints out diagnostic information about the + eigenvalues. + @param exact [in] If true, does the exact computation; if false, does + a much faster (but almost exact) computation based on the Lanczos + method. +*/ + +template +void ComputePca(const MatrixBase &X, + MatrixBase *U, + MatrixBase *A, + bool print_eigs = false, + bool exact = true); + + + +// This function does: *plus += max(0, a b^T), +// *minus += max(0, -(a b^T)). +template +void AddOuterProductPlusMinus(Real alpha, + const VectorBase &a, + const VectorBase &b, + MatrixBase *plus, + MatrixBase *minus); + +template +inline void AssertSameDim(const MatrixBase &mat1, const MatrixBase &mat2) { + KALDI_ASSERT(mat1.NumRows() == mat2.NumRows() + && mat1.NumCols() == mat2.NumCols()); +} + + +/// @} end of "addtogroup matrix_funcs_misc" + +} // end namespace kaldi + +#include "matrix/matrix-functions-inl.h" + + +#endif diff --git a/speechx/speechx/kaldi/matrix/matrix-lib.h b/speechx/speechx/kaldi/matrix/matrix-lib.h new file mode 100644 index 00000000..2a5ebad7 --- /dev/null +++ b/speechx/speechx/kaldi/matrix/matrix-lib.h @@ -0,0 +1,37 @@ +// matrix/matrix-lib.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Haihua Xu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. 
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// Include everything from this directory. +// These files include other stuff that we need. +#ifndef KALDI_MATRIX_MATRIX_LIB_H_ +#define KALDI_MATRIX_MATRIX_LIB_H_ + +#include "base/kaldi-common.h" +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" +#include "matrix/sp-matrix.h" +#include "matrix/tp-matrix.h" +#include "matrix/matrix-functions.h" +#include "matrix/srfft.h" +#include "matrix/compressed-matrix.h" +#include "matrix/sparse-matrix.h" +#include "matrix/optimization.h" + +#endif + diff --git a/speechx/speechx/kaldi/matrix/optimization.cc b/speechx/speechx/kaldi/matrix/optimization.cc new file mode 100644 index 00000000..c17b5b94 --- /dev/null +++ b/speechx/speechx/kaldi/matrix/optimization.cc @@ -0,0 +1,577 @@ +// matrix/optimization.cc + +// Copyright 2012 Johns Hopkins University (author: Daniel Povey) + + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +// +// (*) incorporates, with permission, FFT code from his book +// "Signal Processing with Lapped Transforms", Artech, 1992. + +#include + +#include "matrix/optimization.h" +#include "matrix/sp-matrix.h" + +namespace kaldi { + + +// Below, N&W refers to Nocedal and Wright, "Numerical Optimization", 2nd Ed. 
+ +template +OptimizeLbfgs::OptimizeLbfgs(const VectorBase &x, + const LbfgsOptions &opts): + opts_(opts), k_(0), computation_state_(kBeforeStep), H_was_set_(false) { + KALDI_ASSERT(opts.m > 0); // dimension. + MatrixIndexT dim = x.Dim(); + KALDI_ASSERT(dim > 0); + x_ = x; // this is the value of x_k + new_x_ = x; // this is where we'll evaluate the function next. + deriv_.Resize(dim); + temp_.Resize(dim); + data_.Resize(2 * opts.m, dim); + rho_.Resize(opts.m); + // Just set f_ to some invalid value, as we haven't yet set it. + f_ = (opts.minimize ? 1 : -1 ) * std::numeric_limits::infinity(); + best_f_ = f_; + best_x_ = x_; +} + + +template +Real OptimizeLbfgs::RecentStepLength() const { + size_t n = step_lengths_.size(); + if (n == 0) return std::numeric_limits::infinity(); + else { + if (n >= 2 && step_lengths_[n-1] == 0.0 && step_lengths_[n-2] == 0.0) + return 0.0; // two zeros in a row means repeated restarts, which is + // a loop. Short-circuit this by returning zero. + Real avg = 0.0; + for (size_t i = 0; i < n; i++) + avg += step_lengths_[i] / n; + return avg; + } +} + +template +void OptimizeLbfgs::ComputeHifNeeded(const VectorBase &gradient) { + if (k_ == 0) { + if (H_.Dim() == 0) { + // H was never set up. Set it up for the first time. + Real learning_rate; + if (opts_.first_step_length > 0.0) { // this takes + // precedence over first_step_learning_rate, if set. + // We are setting up H for the first time. + Real gradient_length = gradient.Norm(2.0); + learning_rate = (gradient_length > 0.0 ? + opts_.first_step_length / gradient_length : + 1.0); + } else if (opts_.first_step_impr > 0.0) { + Real gradient_length = gradient.Norm(2.0); + learning_rate = (gradient_length > 0.0 ? + opts_.first_step_impr / (gradient_length * gradient_length) : + 1.0); + } else { + learning_rate = opts_.first_step_learning_rate; + } + H_.Resize(x_.Dim()); + KALDI_ASSERT(learning_rate > 0.0); + H_.Set(opts_.minimize ? 
learning_rate : -learning_rate); + } + } else { // k_ > 0 + if (!H_was_set_) { // The user never specified an approximate + // diagonal inverse Hessian. + // Set it using formula 7.20: H_k^{(0)} = \gamma_k I, where + // \gamma_k = s_{k-1}^T y_{k-1} / y_{k-1}^T y_{k-1} + SubVector y_km1 = Y(k_-1); + double gamma_k = VecVec(S(k_-1), y_km1) / VecVec(y_km1, y_km1); + if (KALDI_ISNAN(gamma_k) || KALDI_ISINF(gamma_k)) { + KALDI_WARN << "NaN encountered in L-BFGS (already converged?)"; + gamma_k = (opts_.minimize ? 1.0 : -1.0); + } + H_.Set(gamma_k); + } + } +} + +// This represents the first 2 lines of Algorithm 7.5 (N&W), which +// in fact is mostly a call to Algorithm 7.4. +// Note: this is valid whether we are minimizing or maximizing. +template +void OptimizeLbfgs::ComputeNewDirection(Real function_value, + const VectorBase &gradient) { + KALDI_ASSERT(computation_state_ == kBeforeStep); + SignedMatrixIndexT m = M(), k = k_; + ComputeHifNeeded(gradient); + // The rest of this is computing p_k <-- - H_k \nabla f_k using Algorithm + // 7.4 of N&W. + Vector &q(deriv_), &r(new_x_); // Use deriv_ as a temporary place to put + // q, and new_x_ as a temporay place to put r. + // The if-statement below is just to get rid of spurious warnings from + // valgrind about memcpy source and destination overlap, since sometimes q and + // gradient are the same variable. + if (&q != &gradient) + q.CopyFromVec(gradient); // q <-- \nabla f_k. + Vector alpha(m); + // for i = k - 1, k - 2, ... k - m + for (SignedMatrixIndexT i = k - 1; + i >= std::max(k - m, static_cast(0)); + i--) { + alpha(i % m) = rho_(i % m) * VecVec(S(i), q); // \alpha_i <-- \rho_i s_i^T q. + q.AddVec(-alpha(i % m), Y(i)); // q <-- q - \alpha_i y_i + } + r.SetZero(); + r.AddVecVec(1.0, H_, q, 0.0); // r <-- H_k^{(0)} q. + // for k = k - m, k - m + 1, ... 
, k - 1 + for (SignedMatrixIndexT i = std::max(k - m, static_cast(0)); + i < k; + i++) { + Real beta = rho_(i % m) * VecVec(Y(i), r); // \beta <-- \rho_i y_i^T r + r.AddVec(alpha(i % m) - beta, S(i)); // r <-- r + s_i (\alpha_i - \beta) + } + + { // TEST. Note, -r will be the direction. + Real dot = VecVec(gradient, r); + if ((opts_.minimize && dot < 0) || (!opts_.minimize && dot > 0)) + KALDI_WARN << "Step direction has the wrong sign! Routine will fail."; + } + + // Now we're out of Alg. 7.4 and back into Alg. 7.5. + // Alg. 7.4 returned r (using new_x_ as the location), and with \alpha_k = 1 + // as the initial guess, we're setting x_{k+1} = x_k + \alpha_k p_k, with + // p_k = -r [hence the statement new_x_.Scale(-1.0)]., and \alpha_k = 1. + // This is the first place we'll get the user to evaluate the function; + // any backtracking (or acceptance of that step) occurs inside StepSizeIteration. + // We're still within iteration k; we haven't yet finalized the step size. + new_x_.Scale(-1.0); + new_x_.AddVec(1.0, x_); + if (&deriv_ != &gradient) + deriv_.CopyFromVec(gradient); + f_ = function_value; + d_ = opts_.d; + num_wolfe_i_failures_ = 0; + num_wolfe_ii_failures_ = 0; + last_failure_type_ = kNone; + computation_state_ = kWithinStep; +} + + +template +bool OptimizeLbfgs::AcceptStep(Real function_value, + const VectorBase &gradient) { + // Save s_k = x_{k+1} - x_{k}, and y_k = \nabla f_{k+1} - \nabla f_k. + SubVector s = S(k_), y = Y(k_); + s.CopyFromVec(new_x_); + s.AddVec(-1.0, x_); // s = new_x_ - x_. + y.CopyFromVec(gradient); + y.AddVec(-1.0, deriv_); // y = gradient - deriv_. + + // Warning: there is a division in the next line. This could + // generate inf or nan, but this wouldn't necessarily be an error + // at this point because for zero step size or derivative we should + // terminate the iterations. But this is up to the calling code. 
+ Real prod = VecVec(y, s); + rho_(k_ % opts_.m) = 1.0 / prod; + Real len = s.Norm(2.0); + + if ((opts_.minimize && prod <= 1.0e-20) || (!opts_.minimize && prod >= -1.0e-20) + || len == 0.0) + return false; // This will force restart. + + KALDI_VLOG(3) << "Accepted step; length was " << len + << ", prod was " << prod; + RecordStepLength(len); + + // store x_{k+1} and the function value f_{k+1}. + x_.CopyFromVec(new_x_); + f_ = function_value; + k_++; + + return true; // We successfully accepted the step. +} + +template +void OptimizeLbfgs::RecordStepLength(Real s) { + step_lengths_.push_back(s); + if (step_lengths_.size() > static_cast(opts_.avg_step_length)) + step_lengths_.erase(step_lengths_.begin(), step_lengths_.begin() + 1); +} + + +template +void OptimizeLbfgs::Restart(const VectorBase &x, + Real f, + const VectorBase &gradient) { + // Note: we will consider restarting (the transition of x_ -> x) + // as a step, even if it has zero step size. This is necessary in + // order for convergence to be detected. + { + Vector &diff(temp_); + diff.CopyFromVec(x); + diff.AddVec(-1.0, x_); + RecordStepLength(diff.Norm(2.0)); + } + k_ = 0; // Restart the iterations! [But note that the Hessian, + // whatever it was, stays as before.] + if (&x_ != &x) + x_.CopyFromVec(x); + new_x_.CopyFromVec(x); + f_ = f; + computation_state_ = kBeforeStep; + ComputeNewDirection(f, gradient); +} + +template +void OptimizeLbfgs::StepSizeIteration(Real function_value, + const VectorBase &gradient) { + KALDI_VLOG(3) << "In step size iteration, function value changed " + << f_ << " to " << function_value; + + // We're in some part of the backtracking, and the user is providing + // the objective function value and gradient. + // We're checking two conditions: Wolfe i) [the Armijo rule] and + // Wolfe ii). + + // The Armijo rule (when minimizing) is: + // f(k_k + \alpha_k p_k) <= f(x_k) + c_1 \alpha_k p_k^T \nabla f(x_k), where + // \nabla means the derivative. 
+ // Below, "temp" is the RHS of this equation, where (\alpha_k p_k) equals + // (new_x_ - x_); we don't store \alpha or p_k separately, they are implicit + // as the difference new_x_ - x_. + + // Below, pf is \alpha_k p_k^T \nabla f(x_k). + Real pf = VecVec(new_x_, deriv_) - VecVec(x_, deriv_); + Real temp = f_ + opts_.c1 * pf; + + bool wolfe_i_ok; + if (opts_.minimize) wolfe_i_ok = (function_value <= temp); + else wolfe_i_ok = (function_value >= temp); + + // Wolfe condition ii) can be written as: + // p_k^T \nabla f(x_k + \alpha_k p_k) >= c_2 p_k^T \nabla f(x_k) + // p2f equals \alpha_k p_k^T \nabla f(x_k + \alpha_k p_k), where + // (\alpha_k p_k^T) is (new_x_ - x_). + // Note that in our version of Wolfe condition (ii) we have an extra + // factor alpha, which doesn't affect anything. + Real p2f = VecVec(new_x_, gradient) - VecVec(x_, gradient); + //eps = (sizeof(Real) == 4 ? 1.0e-05 : 1.0e-10) * + //(std::abs(p2f) + std::abs(pf)); + bool wolfe_ii_ok; + if (opts_.minimize) wolfe_ii_ok = (p2f >= opts_.c2 * pf); + else wolfe_ii_ok = (p2f <= opts_.c2 * pf); + + enum { kDecrease, kNoChange } d_action; // What to do with d_: leave it alone, + // or take the square root. + enum { kAccept, kDecreaseStep, kIncreaseStep, kRestart } iteration_action; + // What we'll do in the overall iteration: accept this value, DecreaseStep + // (reduce the step size), IncreaseStep (increase the step size), or kRestart + // (set k back to zero). Generally when we can't get both conditions to be + // true with a reasonable period of time, it makes sense to restart, because + // probably we've almost converged and got into numerical issues; from here + // we'll just produce NaN's. Restarting is a safe thing to do and the outer + // code will quickly detect convergence. + + d_action = kNoChange; // the default. + + if (wolfe_i_ok && wolfe_ii_ok) { + iteration_action = kAccept; + d_action = kNoChange; // actually doesn't matter, it'll get reset.
+ } else if (!wolfe_i_ok) { + // If wolfe i) [the Armijo rule] failed then we went too far (or are + // meeting numerical problems). + if (last_failure_type_ == kWolfeII) { // Last time we failed it was Wolfe ii). + // When we switch between them we decrease d. + d_action = kDecrease; + } + iteration_action = kDecreaseStep; + last_failure_type_ = kWolfeI; + num_wolfe_i_failures_++; + } else if (!wolfe_ii_ok) { + // Curvature condition failed -> we did not go far enough. + if (last_failure_type_ == kWolfeI) // switching between wolfe i and ii failures-> + d_action = kDecrease; // decrease value of d. + iteration_action = kIncreaseStep; + last_failure_type_ = kWolfeII; + num_wolfe_ii_failures_++; + } + + // Test whether we've been switching too many times between wolfe i) and ii) + // failures, or overall have an excessive number of failures. We just give up + // and restart L-BFGS. Probably we've almost converged. + if (num_wolfe_i_failures_ + num_wolfe_ii_failures_ > + opts_.max_line_search_iters) { + KALDI_VLOG(2) << "Too many steps in line search -> restarting."; + iteration_action = kRestart; + } + + if (d_action == kDecrease) + d_ = std::sqrt(d_); + + KALDI_VLOG(3) << "d = " << d_ << ", iter = " << k_ << ", action = " + << (iteration_action == kAccept ? "accept" : + (iteration_action == kDecreaseStep ? "decrease" : + (iteration_action == kIncreaseStep ? "increase" : + "reject"))); + + // Note: even if iteration_action != Restart at this point, + // some code below may set it to Restart. + if (iteration_action == kAccept) { + if (AcceptStep(function_value, gradient)) { // If we did + // not detect a problem while accepting the step.. + computation_state_ = kBeforeStep; + ComputeNewDirection(function_value, gradient); + } else { + KALDI_VLOG(2) << "Restarting L-BFGS computation; problem found while " + << "accepting step."; + iteration_action = kRestart; // We'll have to restart now.
+ } + } + if (iteration_action == kDecreaseStep || iteration_action == kIncreaseStep) { + Real scale = (iteration_action == kDecreaseStep ? 1.0 / d_ : d_); + temp_.CopyFromVec(new_x_); + new_x_.Scale(scale); + new_x_.AddVec(1.0 - scale, x_); + if (new_x_.ApproxEqual(temp_, 0.0)) { + // Value of new_x_ did not change at all --> we must restart. + KALDI_VLOG(3) << "Value of x did not change, when taking step; " + << "will restart computation."; + iteration_action = kRestart; + } + if (new_x_.ApproxEqual(temp_, 1.0e-08) && + std::abs(f_ - function_value) < 1.0e-08 * + std::abs(f_) && iteration_action == kDecreaseStep) { + // This is common and due to roundoff. + KALDI_VLOG(3) << "We appear to be backtracking while we are extremely " + << "close to the old value; restarting."; + iteration_action = kRestart; + } + + if (iteration_action == kDecreaseStep) { + num_wolfe_i_failures_++; + last_failure_type_ = kWolfeI; + } else { + num_wolfe_ii_failures_++; + last_failure_type_ = kWolfeII; + } + } + if (iteration_action == kRestart) { + // We want to restart the computation. If the objf at new_x_ is + // better than it was at x_, we'll start at new_x_, else at x_. + bool use_newx; + if (opts_.minimize) use_newx = (function_value < f_); + else use_newx = (function_value > f_); + KALDI_VLOG(3) << "Restarting computation."; + if (use_newx) Restart(new_x_, function_value, gradient); + else Restart(x_, f_, deriv_); + } +} + +template +void OptimizeLbfgs::DoStep(Real function_value, + const VectorBase &gradient) { + if (opts_.minimize ? function_value < best_f_ : function_value > best_f_) { + best_f_ = function_value; + best_x_.CopyFromVec(new_x_); + } + if (computation_state_ == kBeforeStep) + ComputeNewDirection(function_value, gradient); + else // kWithinStep{1,2,3} + StepSizeIteration(function_value, gradient); +} + +template +void OptimizeLbfgs::DoStep(Real function_value, + const VectorBase &gradient, + const VectorBase &diag_approx_2nd_deriv) { + if (opts_.minimize ? 
function_value < best_f_ : function_value > best_f_) { + best_f_ = function_value; + best_x_.CopyFromVec(new_x_); + } + if (opts_.minimize) { + KALDI_ASSERT(diag_approx_2nd_deriv.Min() > 0.0); + } else { + KALDI_ASSERT(diag_approx_2nd_deriv.Max() < 0.0); + } + H_was_set_ = true; + H_.CopyFromVec(diag_approx_2nd_deriv); + H_.InvertElements(); + DoStep(function_value, gradient); +} + +template +const VectorBase& +OptimizeLbfgs::GetValue(Real *objf_value) const { + if (objf_value != NULL) *objf_value = best_f_; + return best_x_; +} + +// to compute the alpha, we are minimizing f(x) = x^T b - 0.5 x_k^T A x_k along +// direction p_k... consider alpha +// d/dx of f(x) = b - A x_k = r. + +// Notation based on Sec. 5.1 of Nocedal and Wright +// Computation based on Alg. 5.2 of Nocedal and Wright (Pg. 112) +// Notation (replicated for convenience): +// To solve Ax=b for x +// k : current iteration +// x_k : estimate of x (at iteration k) +// r_k : residual ( r_k \eqdef A x_k - b ) +// \alpha_k : step size +// p_k : A-conjugate direction +// \beta_k : coefficient used in A-conjugate direction computation for next +// iteration +// +// Algo. LinearCG(A,b,x_0) +// ======================== +// r_0 = Ax_0 - b +// p_0 = -r_0 +// k = 0 +// +// while r_k != 0 +// \alpha_k = (r_k^T r_k) / (p_k^T A p_k) +// x_{k+1} = x_k + \alpha_k p_k; +// r_{k+1} = r_k + \alpha_k A p_k +// \beta_{k+1} = \frac{r_{k+1}^T r_{k+1}}{r_k^T r_K} +// p_{k+1} = -r_{k+1} + \beta_{k+1} p_k +// k = k + 1 +// end + +template +int32 LinearCgd(const LinearCgdOptions &opts, + const SpMatrix &A, + const VectorBase &b, + VectorBase *x) { + // Initialize the variables + // + int32 M = A.NumCols(); + + Matrix storage(4, M); + SubVector r(storage, 0), p(storage, 1), Ap(storage, 2), x_orig(storage, 3); + p.CopyFromVec(b); + p.AddSpVec(-1.0, A, *x, 1.0); // p_0 = b - A x_0 + r.AddVec(-1.0, p); // r_0 = - p_0 + x_orig.CopyFromVec(*x); // in case of failure. 
+ + Real r_cur_norm_sq = VecVec(r, r), + r_initial_norm_sq = r_cur_norm_sq, + r_recompute_norm_sq = r_cur_norm_sq; + + KALDI_VLOG(5) << "In linear CG: initial norm-square of residual = " + << r_initial_norm_sq; + + KALDI_ASSERT(opts.recompute_residual_factor <= 1.0); + Real max_error_sq = std::max(opts.max_error * opts.max_error, + std::numeric_limits::min()), + residual_factor = opts.recompute_residual_factor * + opts.recompute_residual_factor, + inv_residual_factor = 1.0 / residual_factor; + + // Note: although from a mathematical point of view the method should converge + // after M iterations, in practice (due to roundoff) it does not always + // converge to good precision after that many iterations so we let the maximum + // be M + 5 instead. + int32 k = 0; + for (; k < M + 5 && k != opts.max_iters; k++) { + // Note: we'll break from this loop if we converge sooner due to + // max_error. + Ap.AddSpVec(1.0, A, p, 0.0); // Ap = A p + + // Below is how the code used to look. + // // next line: \alpha_k = (r_k^T r_k) / (p_k^T A p_k) + // Real alpha = r_cur_norm_sq / VecVec(p, Ap); + // + // We changed r_cur_norm_sq below to -VecVec(p, r). Although this is + // slightly less efficient, it seems to make the algorithm dramatically more + // robust. Note that -p^T r is the mathematically more natural quantity to + // use here, that corresponds to minimizing along that direction... r^T r is + // recommended in Nocedal and Wright only as a kind of optimization as it is + // supposed to be the same as -p^T r and we already have it computed. 
+ Real alpha = -VecVec(p, r) / VecVec(p, Ap); + + // next line: x_{k+1} = x_k + \alpha_k p_k; + x->AddVec(alpha, p); + // next line: r_{k+1} = r_k + \alpha_k A p_k + r.AddVec(alpha, Ap); + Real r_next_norm_sq = VecVec(r, r); + + if (r_next_norm_sq < residual_factor * r_recompute_norm_sq || + r_next_norm_sq > inv_residual_factor * r_recompute_norm_sq) { + + // Recompute the residual from scratch if the residual norm has decreased + // a lot; this costs an extra matrix-vector multiply, but helps keep the + // residual accurate. + // Also do the same if the residual norm has increased a lot since + // the last time we recomputed... this shouldn't happen often, but + // it can indicate bad stuff is happening. + + // r_{k+1} = A x_{k+1} - b + r.AddSpVec(1.0, A, *x, 0.0); + r.AddVec(-1.0, b); + r_next_norm_sq = VecVec(r, r); + r_recompute_norm_sq = r_next_norm_sq; + + KALDI_VLOG(5) << "In linear CG: recomputing residual."; + } + KALDI_VLOG(5) << "In linear CG: k = " << k + << ", r_next_norm_sq = " << r_next_norm_sq; + // Check if converged. + if (r_next_norm_sq <= max_error_sq) + break; + + // next line: \beta_{k+1} = \frac{r_{k+1}^T r_{k+1}}{r_k^T r_K} + Real beta_next = r_next_norm_sq / r_cur_norm_sq; + // next lines: p_{k+1} = -r_{k+1} + \beta_{k+1} p_k + Vector p_old(p); + p.Scale(beta_next); + p.AddVec(-1.0, r); + r_cur_norm_sq = r_next_norm_sq; + } + + // note: the first element of the && is only there to save compute. + // the residual r is A x - b, and r_cur_norm_sq and r_initial_norm_sq are + // of the form r * r, so it's clear that b * b has the right dimension to + // compare with the residual. + if (r_cur_norm_sq > r_initial_norm_sq && + r_cur_norm_sq > r_initial_norm_sq + 1.0e-10 * VecVec(b, b)) { + KALDI_WARN << "Doing linear CGD in dimension " << A.NumRows() << ", after " << k + << " iterations the squared residual has got worse, " + << r_cur_norm_sq << " > " << r_initial_norm_sq + << ". 
Will do an exact optimization."; + SolverOptions opts("called-from-linearCGD"); + x->CopyFromVec(x_orig); + SolveQuadraticProblem(A, b, opts, x); + } + return k; +} + +// Instantiate the class for float and double. +template +class OptimizeLbfgs; +template +class OptimizeLbfgs; + + +template +int32 LinearCgd(const LinearCgdOptions &opts, + const SpMatrix &A, const VectorBase &b, + VectorBase *x); + +template +int32 LinearCgd(const LinearCgdOptions &opts, + const SpMatrix &A, const VectorBase &b, + VectorBase *x); + +} // end namespace kaldi diff --git a/speechx/speechx/kaldi/matrix/optimization.h b/speechx/speechx/kaldi/matrix/optimization.h new file mode 100644 index 00000000..66309aca --- /dev/null +++ b/speechx/speechx/kaldi/matrix/optimization.h @@ -0,0 +1,248 @@ +// matrix/optimization.h + +// Copyright 2012 Johns Hopkins University (author: Daniel Povey) +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +// +// (*) incorporates, with permission, FFT code from his book +// "Signal Processing with Lapped Transforms", Artech, 1992. 
+ + + +#ifndef KALDI_MATRIX_OPTIMIZATION_H_ +#define KALDI_MATRIX_OPTIMIZATION_H_ + +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + + +/// @addtogroup matrix_optimization +/// @{ + +struct LinearCgdOptions { + int32 max_iters; // Maximum number of iters (if >= 0). + BaseFloat max_error; // Maximum 2-norm of the residual A x - b (convergence + // test) + // Every time the residual 2-norm decreases by this recompute_residual_factor + // since the last time it was computed from scratch, recompute it from + // scratch. This helps to keep the computed residual accurate even in the + // presence of roundoff. + BaseFloat recompute_residual_factor; + + LinearCgdOptions(): max_iters(-1), + max_error(0.0), + recompute_residual_factor(0.01) { } +}; + +/* + This function uses linear conjugate gradient descent to approximately solve + the system A x = b. The value of x at entry corresponds to the initial guess + of x. The algorithm continues until the number of iterations equals b.Dim(), + or until the 2-norm of (A x - b) is <= max_error, or until the number of + iterations equals max_iter, whichever happens sooner. It is a requirement + that A be positive definite. + It returns the number of iterations that were actually executed (this is + useful for testing purposes). +*/ +template +int32 LinearCgd(const LinearCgdOptions &opts, + const SpMatrix &A, const VectorBase &b, + VectorBase *x); + + + + + + +/** + This is an implementation of L-BFGS. It pushes responsibility for + determining when to stop, onto the user. There is no call-back here: + everything is done via calls to the class itself (see the example in + matrix-lib-test.cc). This does not implement constrained L-BFGS, but it will + handle constrained problems correctly as long as the function approaches + +infinity (or -infinity for maximization problems) when it gets close to the + bound of the constraint. 
In these types of problems, you just let the + function value be +infinity for minimization problems, or -infinity for + maximization problems, outside these bounds). +*/ + +struct LbfgsOptions { + bool minimize; // if true, we're minimizing, else maximizing. + int m; // m is the number of stored vectors L-BFGS keeps. + float first_step_learning_rate; // The very first step of L-BFGS is + // like gradient descent. If you want to configure the size of that step, + // you can do it using this variable. + float first_step_length; // If this variable is >0.0, it overrides + // first_step_learning_rate; on the first step we choose an approximate + // Hessian that is the multiple of the identity that would generate this + // step-length, or 1.0 if the gradient is zero. + float first_step_impr; // If this variable is >0.0, it overrides + // first_step_learning_rate; on the first step we choose an approximate + // Hessian that is the multiple of the identity that would generate this + // amount of objective function improvement (assuming the "real" objf + // was linear). + float c1; // A constant in Armijo rule = Wolfe condition i) + float c2; // A constant in Wolfe condition ii) + float d; // An amount > 1.0 (default 2.0) that we initially multiply or + // divide the step length by, in the line search. + int max_line_search_iters; // after this many iters we restart L-BFGS. + int avg_step_length; // number of iters to avg step length over, in + // RecentStepLength(). + + LbfgsOptions (bool minimize = true): + minimize(minimize), + m(10), + first_step_learning_rate(1.0), + first_step_length(0.0), + first_step_impr(0.0), + c1(1.0e-04), + c2(0.9), + d(2.0), + max_line_search_iters(50), + avg_step_length(4) { } +}; + +template +class OptimizeLbfgs { + public: + /// Initializer takes the starting value of x. 
+ OptimizeLbfgs(const VectorBase &x, + const LbfgsOptions &opts); + + /// This returns the value of the variable x that has the best objective + /// function so far, and the corresponding objective function value if + /// requested. This would typically be called only at the end. + const VectorBase& GetValue(Real *objf_value = NULL) const; + + /// This returns the value at which the function wants us + /// to compute the objective function and gradient. + const VectorBase& GetProposedValue() const { return new_x_; } + + /// Returns the average magnitude of the last n steps (but not + /// more than the number we have stored). Before we have taken + /// any steps, returns +infinity. Note: if the most recent + /// step length was 0, it returns 0, regardless of the other + /// step lengths. This makes it suitable as a convergence test + /// (else we'd generate NaN's). + Real RecentStepLength() const; + + /// The user calls this function to provide the class with the + /// function and gradient info at the point GetProposedValue(). + /// If this point is outside the constraints you can set function_value + /// to {+infinity,-infinity} for {minimization,maximization} problems. + /// In this case the gradient, and also the second derivative (if you call + /// the second overloaded version of this function) will be ignored. + void DoStep(Real function_value, + const VectorBase &gradient); + + /// The user can call this version of DoStep() if it is desired to set some + /// kind of approximate Hessian on this iteration. Note: it is a prerequisite + /// that diag_approx_2nd_deriv must be strictly positive (minimizing), or + /// negative (maximizing). + void DoStep(Real function_value, + const VectorBase &gradient, + const VectorBase &diag_approx_2nd_deriv); + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(OptimizeLbfgs); + + + // The following variable says what stage of the computation we're at. 
+ // Refer to Algorithm 7.5 (L-BFGS) of Nodecdal & Wright, "Numerical + // Optimization", 2nd edition. + // kBeforeStep means we're about to do + /// "compute p_k <-- - H_k \delta f_k" (i.e. Algorithm 7.4). + // kWithinStep means we're at some point within line search; note + // that line search is iterative so we can stay in this state more + // than one time on each iteration. + enum ComputationState { + kBeforeStep, + kWithinStep, // This means we're within the step-size computation, and + // have not yet done the 1st function evaluation. + }; + + inline MatrixIndexT Dim() { return x_.Dim(); } + inline MatrixIndexT M() { return opts_.m; } + SubVector Y(MatrixIndexT i) { + return SubVector(data_, (i % M()) * 2); // vector y_i + } + SubVector S(MatrixIndexT i) { + return SubVector(data_, (i % M()) * 2 + 1); // vector s_i + } + // The following are subroutines within DoStep(): + bool AcceptStep(Real function_value, + const VectorBase &gradient); + void Restart(const VectorBase &x, + Real function_value, + const VectorBase &gradient); + void ComputeNewDirection(Real function_value, + const VectorBase &gradient); + void ComputeHifNeeded(const VectorBase &gradient); + void StepSizeIteration(Real function_value, + const VectorBase &gradient); + void RecordStepLength(Real s); + + + LbfgsOptions opts_; + SignedMatrixIndexT k_; // Iteration number, starts from zero. Gets set back to zero + // when we restart. + + ComputationState computation_state_; + bool H_was_set_; // True if the user specified H_; if false, + // we'll use a heuristic to estimate it. + + + Vector x_; // current x. + Vector new_x_; // the x proposed in the line search. + Vector best_x_; // the x with the best objective function so far + // (either the same as x_ or something in the current line search.) + Vector deriv_; // The most recently evaluated derivative-- at x_k. + Vector temp_; + Real f_; // The function evaluated at x_k. + Real best_f_; // the best objective function so far. 
+ Real d_; // a number d > 1.0, but during an iteration we may decrease this, when + // we switch between armijo and wolfe failures. + + int num_wolfe_i_failures_; // the num times we decreased step size. + int num_wolfe_ii_failures_; // the num times we increased step size. + enum { kWolfeI, kWolfeII, kNone } last_failure_type_; // last type of step-search + // failure on this iter. + + Vector H_; // Current inverse-Hessian estimate. May be computed by this class itself, + // or provided by user using 2nd form of SetGradientInfo(). + Matrix data_; // dimension (m*2) x dim. Even rows store + // gradients y_i, odd rows store steps s_i. + Vector rho_; // dimension m; rho_(m) = 1/(y_m^T s_m), Eq. 7.17. + + std::vector step_lengths_; // The step sizes we took on the last + // (up to m) iterations; these are not stored in a rotating buffer but + // are shifted by one each time (this is more convenient when we + // restart, as we keep this info past restarting). + + +}; + +/// @} + + +} // end namespace kaldi + + + +#endif + diff --git a/speechx/speechx/kaldi/matrix/packed-matrix.cc b/speechx/speechx/kaldi/matrix/packed-matrix.cc new file mode 100644 index 00000000..80bf5891 --- /dev/null +++ b/speechx/speechx/kaldi/matrix/packed-matrix.cc @@ -0,0 +1,438 @@ +// matrix/packed-matrix.cc + +// Copyright 2009-2012 Microsoft Corporation Saarland University +// Johns Hopkins University (Author: Daniel Povey); +// Haihua Xu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +/** + * @file packed-matrix.cc + * + * Implementation of specialized PackedMatrix template methods + */ +#include "matrix/cblas-wrappers.h" +#include "matrix/packed-matrix.h" +#include "matrix/kaldi-vector.h" + +namespace kaldi { + +template +void PackedMatrix::Scale(Real alpha) { + size_t nr = num_rows_, + sz = (nr * (nr + 1)) / 2; + cblas_Xscal(sz, alpha, data_, 1); +} + +template +void PackedMatrix::AddPacked(const Real alpha, const PackedMatrix &rMa) { + KALDI_ASSERT(num_rows_ == rMa.NumRows()); + size_t nr = num_rows_, + sz = (nr * (nr + 1)) / 2; + cblas_Xaxpy(sz, alpha, rMa.Data(), 1, data_, 1); +} + +template +void PackedMatrix::SetRandn() { + Real *data = data_; + size_t dim = num_rows_, size = ((dim*(dim+1))/2); + for (size_t i = 0; i < size; i++) + data[i] = RandGauss(); +} + +template +inline void PackedMatrix::Init(MatrixIndexT r) { + if (r == 0) { + num_rows_ = 0; + data_ = 0; + return; + } + size_t size = ((static_cast(r) * static_cast(r + 1)) / 2); + + if (static_cast(static_cast(size)) != size) { + KALDI_WARN << "Allocating packed matrix whose full dimension does not fit " + << "in MatrixIndexT: not all code is tested for this case."; + } + + void *data; // aligned memory block + void *temp; + + if ((data = KALDI_MEMALIGN(16, size * sizeof(Real), &temp)) != NULL) { + this->data_ = static_cast (data); + this->num_rows_ = r; + } else { + throw std::bad_alloc(); + } +} + +template +void PackedMatrix::Swap(PackedMatrix *other) { + std::swap(data_, other->data_); + 
std::swap(num_rows_, other->num_rows_); +} + +template +void PackedMatrix::Swap(Matrix *other) { + std::swap(data_, other->data_); + std::swap(num_rows_, other->num_rows_); +} + + +template +void PackedMatrix::Resize(MatrixIndexT r, MatrixResizeType resize_type) { + // the next block uses recursion to handle what we have to do if + // resize_type == kCopyData. + if (resize_type == kCopyData) { + if (this->data_ == NULL || r == 0) resize_type = kSetZero; // nothing to copy. + else if (this->num_rows_ == r) { return; } // nothing to do. + else { + // set tmp to a packed matrix of the desired size. + PackedMatrix tmp(r, kUndefined); + size_t r_min = std::min(r, num_rows_); + size_t mem_size_min = sizeof(Real) * (r_min*(r_min+1))/2, + mem_size_full = sizeof(Real) * (r*(r+1))/2; + // Copy the contents to tmp. + memcpy(tmp.data_, data_, mem_size_min); + char *ptr = static_cast(static_cast(tmp.data_)); + // Set the rest of the contents of tmp to zero. + memset(static_cast(ptr + mem_size_min), 0, mem_size_full-mem_size_min); + tmp.Swap(this); + return; + } + } + if (data_ != NULL) Destroy(); + Init(r); + if (resize_type == kSetZero) SetZero(); +} + + + +template +void PackedMatrix::AddToDiag(Real r) { + Real *ptr = data_; + for (MatrixIndexT i = 2; i <= num_rows_+1; i++) { + *ptr += r; + ptr += i; + } +} + +template +void PackedMatrix::ScaleDiag(Real alpha) { + Real *ptr = data_; + for (MatrixIndexT i = 2; i <= num_rows_+1; i++) { + *ptr *= alpha; + ptr += i; + } +} + +template +void PackedMatrix::SetDiag(Real alpha) { + Real *ptr = data_; + for (MatrixIndexT i = 2; i <= num_rows_+1; i++) { + *ptr = alpha; + ptr += i; + } +} + + + +template +template +void PackedMatrix::CopyFromPacked(const PackedMatrix &orig) { + KALDI_ASSERT(NumRows() == orig.NumRows()); + if (sizeof(Real) == sizeof(OtherReal)) { + memcpy(data_, orig.Data(), SizeInBytes()); + } else { + Real *dst = data_; + const OtherReal *src = orig.Data(); + size_t nr = NumRows(), + size = (nr * (nr + 1)) / 2; + for 
(size_t i = 0; i < size; i++, dst++, src++) + *dst = *src; + } +} + +// template instantiations. +template +void PackedMatrix::CopyFromPacked(const PackedMatrix &orig); +template +void PackedMatrix::CopyFromPacked(const PackedMatrix &orig); +template +void PackedMatrix::CopyFromPacked(const PackedMatrix &orig); +template +void PackedMatrix::CopyFromPacked(const PackedMatrix &orig); + + + +template +template +void PackedMatrix::CopyFromVec(const SubVector &vec) { + MatrixIndexT size = (NumRows()*(NumRows()+1)) / 2; + KALDI_ASSERT(vec.Dim() == size); + if (sizeof(Real) == sizeof(OtherReal)) { + memcpy(data_, vec.Data(), size * sizeof(Real)); + } else { + Real *dst = data_; + const OtherReal *src = vec.Data(); + for (MatrixIndexT i = 0; i < size; i++, dst++, src++) + *dst = *src; + } +} + +// template instantiations. +template +void PackedMatrix::CopyFromVec(const SubVector &orig); +template +void PackedMatrix::CopyFromVec(const SubVector &orig); +template +void PackedMatrix::CopyFromVec(const SubVector &orig); +template +void PackedMatrix::CopyFromVec(const SubVector &orig); + + + +template +void PackedMatrix::SetZero() { + memset(data_, 0, SizeInBytes()); +} + +template +void PackedMatrix::SetUnit() { + memset(data_, 0, SizeInBytes()); + for (MatrixIndexT row = 0;row < num_rows_;row++) + (*this)(row, row) = 1.0; +} + +template +Real PackedMatrix::Trace() const { + Real ans = 0.0; + for (MatrixIndexT row = 0;row < num_rows_;row++) + ans += (*this)(row, row); + return ans; +} + +template +void PackedMatrix::Destroy() { + // we need to free the data block if it was defined + if (data_ != NULL) KALDI_MEMALIGN_FREE(data_); + data_ = NULL; + num_rows_ = 0; +} + + +template +void PackedMatrix::Write(std::ostream &os, bool binary) const { + if (!os.good()) { + KALDI_ERR << "Failed to write vector to stream: stream not good"; + } + + int32 size = this->NumRows(); // make the size 32-bit on disk. 
+ KALDI_ASSERT(this->NumRows() == (MatrixIndexT) size); + MatrixIndexT num_elems = ((size+1)*(MatrixIndexT)size)/2; + + if(binary) { + std::string my_token = (sizeof(Real) == 4 ? "FP" : "DP"); + WriteToken(os, binary, my_token); + WriteBasicType(os, binary, size); + // We don't use the built-in Kaldi write routines for the floats, as they are + // not efficient enough. + os.write((const char*) data_, sizeof(Real) * num_elems); + } + else { + if(size == 0) + os<<"[ ]\n"; + else { + os<<"[\n"; + MatrixIndexT i = 0; + for (int32 j = 0; j < size; j++) { + for (int32 k = 0; k < j + 1; k++) { + WriteBasicType(os, binary, data_[i++]); + } + os << ( (j==size-1)? "]\n" : "\n"); + } + KALDI_ASSERT(i == num_elems); + } + } + if (os.fail()) { + KALDI_ERR << "Failed to write packed matrix to stream"; + } +} + +// template +// void Save (std::ostream & os, const PackedMatrix& rM) +// { +// const Real* p_elem = rM.data(); +// for (MatrixIndexT i = 0; i < rM.NumRows(); i++) { +// for (MatrixIndexT j = 0; j <= i ; j++) { +// os << *p_elem; +// p_elem++; +// if (j == i) { +// os << '\n'; +// } +// else { +// os << ' '; +// } +// } +// } +// if (os.fail()) +// KALDI_ERR("Failed to write packed matrix to stream"); +// } + + + + + +template +void PackedMatrix::Read(std::istream& is, bool binary, bool add) { + if (add) { + PackedMatrix tmp; + tmp.Read(is, binary, false); // read without adding. + if (this->NumRows() == 0) this->Resize(tmp.NumRows()); + else { + if (this->NumRows() != tmp.NumRows()) { + if (tmp.NumRows() == 0) return; // do nothing in this case. + else KALDI_ERR << "PackedMatrix::Read, size mismatch " << this->NumRows() + << " vs. " << tmp.NumRows(); + } + } + this->AddPacked(1.0, tmp); + return; + } // now assume add == false. + + std::ostringstream specific_error; + MatrixIndexT pos_at_start = is.tellg(); + int peekval = Peek(is, binary); + const char *my_token = (sizeof(Real) == 4 ? 
"FP" : "DP"); + const char *new_format_token = "["; + bool is_new_format = false;//added by hxu + char other_token_start = (sizeof(Real) == 4 ? 'D' : 'F'); + int32 size; + MatrixIndexT num_elems; + + if (peekval == other_token_start) { // need to instantiate the other type to read it. + typedef typename OtherReal::Real OtherType; // if Real == float, OtherType == double, and vice versa. + PackedMatrix other(this->NumRows()); + other.Read(is, binary, false); // add is false at this point. + this->Resize(other.NumRows()); + this->CopyFromPacked(other); + return; + } + std::string token; + ReadToken(is, binary, &token); + if (token != my_token) { + if(token != new_format_token) { + specific_error << ": Expected token " << my_token << ", got " << token; + goto bad; + } + //new format it is + is_new_format = true; + } + if(!is_new_format) { + ReadBasicType(is, binary, &size); // throws on error. + if ((MatrixIndexT)size != this->NumRows()) { + KALDI_ASSERT(size>=0); + this->Resize(size); + } + num_elems = ((size+1)*(MatrixIndexT)size)/2; + if (!binary) { + for (MatrixIndexT i = 0; i < num_elems; i++) { + ReadBasicType(is, false, data_+i); // will throw on error. + } + } else { + if (num_elems) + is.read(reinterpret_cast(data_), sizeof(Real)*num_elems); + } + if (is.fail()) goto bad; + return; + } + else { + std::vector data; + while(1) { + int32 num_lines = 0; + int i = is.peek(); + if (i == -1) { specific_error << "Got EOF while reading matrix data"; goto bad; } + else if (static_cast(i) == ']') { // Finished reading matrix. + is.get(); // eat the "]". + i = is.peek(); + if (static_cast(i) == '\r') { + is.get(); + is.get(); // get \r\n (must eat what we wrote) + }// I don't actually understand what it's doing here + else if (static_cast(i) == '\n') { is.get(); } // get \n (must eat what we wrote) + + if (is.fail()) { + KALDI_WARN << "After end of matrix data, read error."; + // we got the data we needed, so just warn for this error. 
+ } + //now process the data: + num_lines = int32(sqrt(data.size()*2)); + + KALDI_ASSERT(data.size() == num_lines*(num_lines+1)/2); + + this->Resize(num_lines); + + //std::cout<= '0' && i <= '9') || i == '-' ) { // A number... + Real r; + is >> r; + if (is.fail()) { + specific_error << "Stream failure/EOF while reading matrix data."; + goto bad; + } + data.push_back(r); + } + else if (isspace(i)) { + is.get(); // eat the space and do nothing. + } else { // NaN or inf or error. + std::string str; + is >> str; + if (!KALDI_STRCASECMP(str.c_str(), "inf") || + !KALDI_STRCASECMP(str.c_str(), "infinity")) { + data.push_back(std::numeric_limits::infinity()); + KALDI_WARN << "Reading infinite value into matrix."; + } else if (!KALDI_STRCASECMP(str.c_str(), "nan")) { + data.push_back(std::numeric_limits::quiet_NaN()); + KALDI_WARN << "Reading NaN value into matrix."; + } else { + specific_error << "Expecting numeric matrix data, got " << str; + goto bad; + } + } + } + } +bad: + KALDI_ERR << "Failed to read packed matrix from stream. " << specific_error.str() + << " File position at start is " + << pos_at_start << ", currently " << is.tellg(); +} + + +// Instantiate PackedMatrix for float and double. +template +class PackedMatrix; + +template +class PackedMatrix; + + +} // namespace kaldi + diff --git a/speechx/speechx/kaldi/matrix/packed-matrix.h b/speechx/speechx/kaldi/matrix/packed-matrix.h new file mode 100644 index 00000000..722d932b --- /dev/null +++ b/speechx/speechx/kaldi/matrix/packed-matrix.h @@ -0,0 +1,197 @@ +// matrix/packed-matrix.h + +// Copyright 2009-2013 Ondrej Glembek; Lukas Burget; Microsoft Corporation; +// Saarland University; Yanmin Qian; +// Johns Hopkins University (Author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_PACKED_MATRIX_H_ +#define KALDI_MATRIX_PACKED_MATRIX_H_ + +#include "matrix/matrix-common.h" +#include + +namespace kaldi { + +/// \addtogroup matrix_funcs_io +// we need to declare the friend << operator here +template +std::ostream & operator <<(std::ostream & out, const PackedMatrix& M); + + +/// \addtogroup matrix_group +/// @{ + +/// @brief Packed matrix: base class for triangular and symmetric matrices. +template class PackedMatrix { + friend class CuPackedMatrix; + public: + //friend class CuPackedMatrix; + + PackedMatrix() : data_(NULL), num_rows_(0) {} + + explicit PackedMatrix(MatrixIndexT r, MatrixResizeType resize_type = kSetZero): + data_(NULL) { Resize(r, resize_type); } + + explicit PackedMatrix(const PackedMatrix &orig) : data_(NULL) { + Resize(orig.num_rows_, kUndefined); + CopyFromPacked(orig); + } + + template + explicit PackedMatrix(const PackedMatrix &orig) : data_(NULL) { + Resize(orig.NumRows(), kUndefined); + CopyFromPacked(orig); + } + + void SetZero(); /// < Set to zero + void SetUnit(); /// < Set to unit matrix. + void SetRandn(); /// < Set to random values of a normal distribution + + Real Trace() const; + + // Needed for inclusion in std::vector + PackedMatrix & operator =(const PackedMatrix &other) { + Resize(other.NumRows()); + CopyFromPacked(other); + return *this; + } + + ~PackedMatrix() { + Destroy(); + } + + /// Set packed matrix to a specified size (can be zero). 
+ /// The value of the new data depends on resize_type: + /// -if kSetZero, the new data will be zero + /// -if kUndefined, the new data will be undefined + /// -if kCopyData, the new data will be the same as the old data in any + /// shared positions, and zero elsewhere. + /// This function takes time proportional to the number of data elements. + void Resize(MatrixIndexT nRows, MatrixResizeType resize_type = kSetZero); + + void AddToDiag(const Real r); // Adds r to diaginal + + void ScaleDiag(const Real alpha); // Scales diagonal by alpha. + + void SetDiag(const Real alpha); // Sets diagonal to this value. + + template + void CopyFromPacked(const PackedMatrix &orig); + + /// CopyFromVec just interprets the vector as having the same layout + /// as the packed matrix. Must have the same dimension, i.e. + /// orig.Dim() == (NumRows()*(NumRows()+1)) / 2; + template + void CopyFromVec(const SubVector &orig); + + Real* Data() { return data_; } + const Real* Data() const { return data_; } + inline MatrixIndexT NumRows() const { return num_rows_; } + inline MatrixIndexT NumCols() const { return num_rows_; } + size_t SizeInBytes() const { + size_t nr = static_cast(num_rows_); + return ((nr * (nr+1)) / 2) * sizeof(Real); + } + + //MatrixIndexT Stride() const { return stride_; } + + // This code is duplicated in child classes to avoid extra levels of calls. + Real operator() (MatrixIndexT r, MatrixIndexT c) const { + KALDI_ASSERT(static_cast(r) < + static_cast(num_rows_) && + static_cast(c) < + static_cast(num_rows_) + && c <= r); + return *(data_ + (r * (r + 1)) / 2 + c); + } + + // This code is duplicated in child classes to avoid extra levels of calls. 
+ Real &operator() (MatrixIndexT r, MatrixIndexT c) { + KALDI_ASSERT(static_cast(r) < + static_cast(num_rows_) && + static_cast(c) < + static_cast(num_rows_) + && c <= r); + return *(data_ + (r * (r + 1)) / 2 + c); + } + + Real Max() const { + KALDI_ASSERT(num_rows_ > 0); + return * (std::max_element(data_, data_ + ((num_rows_*(num_rows_+1))/2) )); + } + + Real Min() const { + KALDI_ASSERT(num_rows_ > 0); + return * (std::min_element(data_, data_ + ((num_rows_*(num_rows_+1))/2) )); + } + + void Scale(Real c); + + friend std::ostream & operator << <> (std::ostream & out, + const PackedMatrix &m); + // Use instead of stream<<*this, if you want to add to existing contents. + // Will throw exception on failure. + void Read(std::istream &in, bool binary, bool add = false); + + void Write(std::ostream &out, bool binary) const; + + void Destroy(); + + /// Swaps the contents of *this and *other. Shallow swap. + void Swap(PackedMatrix *other); + void Swap(Matrix *other); + + + protected: + // Will only be called from this class or derived classes. + void AddPacked(const Real alpha, const PackedMatrix& M); + Real *data_; + MatrixIndexT num_rows_; + //MatrixIndexT stride_; + private: + /// Init assumes the current contents of the class are is invalid (i.e. junk or + /// has already been freed), and it sets the matrixd to newly allocated memory + /// with the specified dimension. dim == 0 is acceptable. The memory contents + /// pointed to by data_ will be undefined. 
+ void Init(MatrixIndexT dim); + +}; +/// @} end "addtogroup matrix_group" + + +/// \addtogroup matrix_funcs_io +/// @{ + +template +std::ostream & operator << (std::ostream & os, const PackedMatrix& M) { + M.Write(os, false); + return os; +} + +template +std::istream & operator >> (std::istream &is, PackedMatrix &M) { + M.Read(is, false); + return is; +} + +/// @} + +} // namespace kaldi + +#endif + diff --git a/speechx/speechx/kaldi/matrix/qr.cc b/speechx/speechx/kaldi/matrix/qr.cc new file mode 100644 index 00000000..861dead0 --- /dev/null +++ b/speechx/speechx/kaldi/matrix/qr.cc @@ -0,0 +1,580 @@ +// matrix/qr.cc + +// Copyright 2012 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "matrix/sp-matrix.h" +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" +#include "matrix/matrix-functions.h" +#include "matrix/cblas-wrappers.h" + +// This file contains an implementation of the Symmetric QR Algorithm +// for the symmetric eigenvalue problem. See Golub and Van Loan, +// 3rd ed., Algorithm 8.3.3. + +namespace kaldi { + + +/* This is from Golub and Van Loan 3rd ed., sec. 5.1.3, + p210. + x is the input of dimenson 'dim', v is the output of dimension + dim, and beta is a scalar. Note: we use zero-based + not one-based indexing. 
*/ +/* +// We are commenting out the function below ("House") because it's not +// needed, but we keep it just to show how we came up with HouseBackward. +template +void House(MatrixIndexT dim, const Real *x, Real *v, Real *beta) { + KALDI_ASSERT(dim > 0); + // To avoid overflow, we first compute the max of x_ (or + // one if that's zero, and we'll replace "x" by x/max(x_i) + // below. The householder vector is anyway invariant to + // the magnitude of x. We could actually avoid this extra loop + // over x if we wanted to be a bit smarter, but anyway this + // doesn't dominate the O(N) performance of the algorithm. + Real s; // s is a scale on x. + { + Real max_x = std::numeric_limits::min(); + for (MatrixIndexT i = 0; i < dim; i++) + max_x = std::max(max_x, (x[i] < 0 ? -x[i] : x[i])); + if (max_x == 0.0) max_x = 1.0; + s = 1.0 / max_x; + } + + Real sigma = 0.0; + v[0] = 1.0; + for (MatrixIndexT i = 1; i < dim; i++) { + sigma += (x[i]*s) * (x[i]*s); + v[i] = x[i]*s; + } + if (sigma == 0.0) *beta = 0.0; + else { + // When we say x1 = x[0], we reference the one-based indexing + // in Golub and Van Loan. + Real x1 = x[0] * s, mu = std::sqrt(x1*x1 + sigma); + if (x1 <= 0) { + v[0] = x1 - mu; + } else { + v[0] = -sigma / (x1 + mu); + KALDI_ASSERT(KALDI_ISFINITE(v[dim-1])); + } + Real v1 = v[0]; + Real v1sq = v1 * v1; + *beta = 2 * v1sq / (sigma + v1sq); + Real inv_v1 = 1.0 / v1; + if (KALDI_ISINF(inv_v1)) { + // can happen if v1 is denormal. + KALDI_ASSERT(v1 == v1 && v1 != 0.0); + for (MatrixIndexT i = 0; i < dim; i++) v[i] /= v1; + } else { + cblas_Xscal(dim, inv_v1, v, 1); + } + if (KALDI_ISNAN(inv_v1)) { + KALDI_ERR << "NaN encountered in HouseBackward"; + } + } +} +*/ + +// This is a backward version of the "House" routine above: +// backward because it's the last index, not the first index of +// the vector that is "special". 
This is convenient in +// the Tridiagonalize routine that uses reversed indexes for +// compatibility with the packed lower triangular format. +template +void HouseBackward(MatrixIndexT dim, const Real *x, Real *v, Real *beta) { + KALDI_ASSERT(dim > 0); + // To avoid overflow, we first compute the max of x_ (or + // one if that's zero, and we'll replace "x" by x/max(x_i) + // below. The householder vector is anyway invariant to + // the magnitude of x. We could actually avoid this extra loop + // over x if we wanted to be a bit smarter, but anyway this + // doesn't dominate the O(N) performance of the algorithm. + Real s; // s is a scale on x. + { + Real max_x = std::numeric_limits::min(); + for (MatrixIndexT i = 0; i < dim; i++) + max_x = std::max(max_x, (x[i] < 0 ? -x[i] : x[i])); + s = 1.0 / max_x; + } + Real sigma = 0.0; + v[dim-1] = 1.0; + for (MatrixIndexT i = 0; i + 1 < dim; i++) { + sigma += (x[i] * s) * (x[i] * s); + v[i] = x[i] * s; + } + KALDI_ASSERT(KALDI_ISFINITE(sigma) && + "Tridiagonalizing matrix that is too large or has NaNs."); + if (sigma == 0.0) *beta = 0.0; + else { + Real x1 = x[dim-1] * s, mu = std::sqrt(x1 * x1 + sigma); + if (x1 <= 0) { + v[dim-1] = x1 - mu; + } else { + v[dim-1] = -sigma / (x1 + mu); + KALDI_ASSERT(KALDI_ISFINITE(v[dim-1])); + } + Real v1 = v[dim-1]; + Real v1sq = v1 * v1; + *beta = 2 * v1sq / (sigma + v1sq); + Real inv_v1 = 1.0 / v1; + if (KALDI_ISINF(inv_v1)) { + // can happen if v1 is denormal. + KALDI_ASSERT(v1 == v1 && v1 != 0.0); + for (MatrixIndexT i = 0; i < dim; i++) v[i] /= v1; + } else { + cblas_Xscal(dim, inv_v1, v, 1); + } + if (KALDI_ISNAN(inv_v1)) { + KALDI_ERR << "NaN encountered in HouseBackward"; + } + } +} + + +/** + This routine tridiagonalizes *this. C.f. Golub and Van Loan 3rd ed., sec. + 8.3.1 (p415). We reverse the order of the indices as it's more natural + with packed lower-triangular matrices to do it this way. 
There's also + a shift from one-based to zero-based indexing, so the index + k is transformed k -> n - k, and a corresponding transpose... + + Let the original *this be A. This algorithms replaces *this with + a tridiagonal matrix T such that T = Q A Q^T for an orthogonal Q. + Caution: Q is transposed vs. Golub and Van Loan. + If Q != NULL it outputs Q. +*/ +template +void SpMatrix::Tridiagonalize(MatrixBase *Q) { + MatrixIndexT n = this->NumRows(); + KALDI_ASSERT(Q == NULL || (Q->NumRows() == n && + Q->NumCols() == n)); + if (Q != NULL) Q->SetUnit(); + Real *data = this->Data(); + Real *qdata = (Q == NULL ? NULL : Q->Data()); + MatrixIndexT qstride = (Q == NULL ? 0 : Q->Stride()); + Vector tmp_v(n-1), tmp_p(n); + Real beta, *v = tmp_v.Data(), *p = tmp_p.Data(), *w = p, *x = p; + for (MatrixIndexT k = n-1; k >= 2; k--) { + MatrixIndexT ksize = ((k+1)*k)/2; + // ksize is the packed size of the lower-triangular matrix of size k, + // which is the size of "all rows previous to this one." + Real *Arow = data + ksize; // In Golub+Van Loan it was A(k+1:n, k), we + // have Arow = A(k, 0:k-1). + HouseBackward(k, Arow, v, &beta); // sets v and beta. + cblas_Xspmv(k, beta, data, v, 1, 0.0, p, 1); // p = beta * A(0:k-1,0:k-1) v + Real minus_half_beta_pv = -0.5 * beta * cblas_Xdot(k, p, 1, v, 1); + cblas_Xaxpy(k, minus_half_beta_pv, v, 1, w, 1); // w = p - (beta p^T v/2) v; + // this relies on the fact that w and p are the same pointer. + // We're doing A(k, k-1) = ||Arow||. It happens that this element + // is indexed at ksize + k - 1 in the packed lower-triangular format. + data[ksize + k - 1] = std::sqrt(cblas_Xdot(k, Arow, 1, Arow, 1)); + for (MatrixIndexT i = 0; i + 1 < k; i++) + data[ksize + i] = 0; // This is not in Golub and Van Loan but is + // necessary if we're not using parts of A to store the Householder + // vectors. + // We're doing A(0:k-1,0:k-1) -= (v w' + w v') + cblas_Xspr2(k, -1.0, v, 1, w, 1, data); + if (Q != NULL) { // C.f. Golub, Q is H_1 .. H_n-2... 
in this + // case we apply them in the opposite order so it's H_n-1 .. H_1, + // but also Q is transposed so we really have Q = H_1 .. H_n-1. + // It's a double negative. + // Anyway, we left-multiply Q by each one. The H_n would each be + // diag(I + beta v v', I) but we don't ever touch the last dims. + // We do (in Matlab notation): + // Q(0:k-1,:) = (I - beta v v') * Q, i.e.: + // Q(:,0:i-1) += -beta v (v' Q(:,0:k-1)v .. let x = -beta Q(0:k-1,:)^T v. + cblas_Xgemv(kTrans, k, n, -beta, qdata, qstride, v, 1, 0.0, x, 1); + // now x = -beta Q(:,0:k-1) v. + // The next line does: Q(:,0:k-1) += v x'. + cblas_Xger(k, n, 1.0, v, 1, x, 1, qdata, qstride); + } + } +} + +// Instantiate these functions, as it wasn't implemented in sp-matrix.cc +// where we instantiated the whole class. +template +void SpMatrix::Tridiagonalize(MatrixBase *Q); +template +void SpMatrix::Tridiagonalize(MatrixBase *Q); + +/// Create Givens rotations, as in Golub and Van Loan 3rd ed., page 216. +template +inline void Givens(Real a, Real b, Real *c, Real *s) { + if (b == 0) { + *c = 1; + *s = 0; + } else { + if (std::abs(b) > std::abs(a)) { + Real tau = -a / b; + *s = 1 / std::sqrt(1 + tau*tau); + *c = *s * tau; + } else { + Real tau = -b / a; + *c = 1 / std::sqrt(1 + tau*tau); + *s = *c * tau; + } + } +} + + +// Some internal code for the QR algorithm: one "QR step". +// This is Golub and Van Loan 3rd ed., Algorithm 8.3.2 "Implicit Symmetric QR step +// with Wilkinson shift." A couple of differences: this code is +// in zero based arithmetic, and we represent Q transposed from +// their Q for memory locality with row-major-indexed matrices. +template +void QrStep(MatrixIndexT n, + Real *diag, + Real *off_diag, + MatrixBase *Q) { + KALDI_ASSERT(n >= 2); + // below, "scale" could be any number; we introduce it to keep the + // floating point quantities within a good range. 
+ Real d = (diag[n-2] - diag[n-1]) / 2.0, + t = off_diag[n-2], + inv_scale = std::max(std::max(std::abs(d), std::abs(t)), + std::numeric_limits::min()), + scale = 1.0 / inv_scale, + d_scaled = d * scale, + off_diag_n2_scaled = off_diag[n-2] * scale, + t2_n_n1_scaled = off_diag_n2_scaled * off_diag_n2_scaled, + sgn_d = (d > 0.0 ? 1.0 : -1.0), + mu = diag[n-1] - inv_scale * t2_n_n1_scaled / + (d_scaled + sgn_d * std::sqrt(d_scaled * d_scaled + t2_n_n1_scaled)), + x = diag[0] - mu, + z = off_diag[0]; + KALDI_ASSERT(KALDI_ISFINITE(x)); + Real *Qdata = (Q == NULL ? NULL : Q->Data()); + MatrixIndexT Qstride = (Q == NULL ? 0 : Q->Stride()), + Qcols = (Q == NULL ? 0 : Q->NumCols()); + for (MatrixIndexT k = 0; k < n-1; k++) { + Real c, s; + Givens(x, z, &c, &s); + // Rotate dimensions k and k+1 with the Givens matrix G, as + // T <== G^T T G. + // In 2d, a Givens matrix is [ c s; -s c ]. Forget about + // the dimension-indexing issues and assume we have a 2x2 + // symmetric matrix [ p q ; q r ] + // We ask our friends at Wolfram Alpha about + // { { c, -s}, {s, c} } * { {p, q}, {q, r} } * { { c, s}, {-s, c} } + // Interpreting the result as [ p', q' ; q', r ] + // p' = c (c p - s q) - s (c q - s r) + // q' = s (c p - s q) + c (c q - s r) + // r' = s (s p + c q) + c (s q + c r) + Real p = diag[k], q = off_diag[k], r = diag[k+1]; + // p is element k,k; r is element k+1,k+1; q is element k,k+1 or k+1,k. + // We'll let the compiler optimize this. + diag[k] = c * (c*p - s*q) - s * (c*q - s*r); + off_diag[k] = s * (c*p - s*q) + c * (c*q - s*r); + diag[k+1] = s * (s*p + c*q) + c * (s*q + c*r); + + // We also have some other elements to think of that + // got rotated in a simpler way: if k>0, + // then element (k, k-1) and (k+1, k-1) get rotated. Here, + // element k+1, k-1 will be present as z; it's the out-of-band + // element that we remembered from last time. This is + // on the left as it's the row indexes that differ, so think of + // this as being premultiplied by G^T. 
In fact we're multiplying + // T by in some sense the opposite/transpose of the Givens rotation. + if (k > 0) { // Note, in rotations, going backward, (x,y) -> ((cx - sy), (sx + cy)) + Real &elem_k_km1 = off_diag[k-1], + elem_kp1_km1 = z; // , tmp = elem_k_km1; + elem_k_km1 = c*elem_k_km1 - s*elem_kp1_km1; + // The next line will set elem_kp1_km1 to zero and we'll never access this + // value, so we comment it out. + // elem_kp1_km1 = s*tmp + c*elem_kp1_km1; + } + if (Q != NULL) + cblas_Xrot(Qcols, Qdata + k*Qstride, 1, + Qdata + (k+1)*Qstride, 1, c, -s); + if (k < n-2) { + // Next is the elements (k+2, k) and (k+2, k-1), to be rotated, again + // backwards. + Real &elem_kp2_k = z, + &elem_kp2_kp1 = off_diag[k+1]; + // Note: elem_kp2_k == z would start off as zero because it's + // two off the diagonal, and not been touched yet. Therefore + // we eliminate it in expressions below, commenting it out. + // If we didn't do this we should set it to zero first. + elem_kp2_k = - s * elem_kp2_kp1; // + c*elem_kp2_k + elem_kp2_kp1 = c * elem_kp2_kp1; // + s*elem_kp2_k (original value). + // The next part is from the algorithm they describe: x = t_{k+1,k} + x = off_diag[k]; + } + } +} + + +// Internal code for the QR algorithm, where the diagonal +// and off-diagonal of the symmetric matrix are represented as +// vectors of length n and n-1. +template +void QrInternal(MatrixIndexT n, + Real *diag, + Real *off_diag, + MatrixBase *Q) { + KALDI_ASSERT(Q == NULL || Q->NumCols() == n); // We may + // later relax the condition that Q->NumCols() == n. + + MatrixIndexT counter = 0, max_iters = 500 + 4*n, // Should never take this many iters. + large_iters = 100 + 2*n; + Real epsilon = (pow(2.0, sizeof(Real) == 4 ? -23.0 : -52.0)); + + for (; counter < max_iters; counter++) { // this takes the place of "until + // q=n"... we'll break out of the + // loop when we converge. 
+ if (counter == large_iters || + (counter > large_iters && (counter - large_iters) % 50 == 0)) { + KALDI_WARN << "Took " << counter + << " iterations in QR (dim is " << n << "), doubling epsilon."; + SubVector d(diag, n), o(off_diag, n-1); + KALDI_WARN << "Diag, off-diag are " << d << " and " << o; + epsilon *= 2.0; + } + for (MatrixIndexT i = 0; i+1 < n; i++) { + if (std::abs(off_diag[i]) <= epsilon * + (std::abs(diag[i]) + std::abs(diag[i+1]))) + off_diag[i] = 0.0; + } + // The next code works out p, q, and npq which is n - p - q. + // For the definitions of q and p, see Golub and Van Loan; we + // partition the n dims into pieces of size (p, n-p-q, q) where + // the part of size q is diagonal and the part of size n-p-p is + // "unreduced", i.e. has no zero off-diagonal elements. + MatrixIndexT q = 0; + // Note: below, "n-q < 2" should more clearly be "n-2-q < 0", but that + // causes problems if MatrixIndexT is unsigned. + while (q < n && (n-q < 2 || off_diag[n-2-q] == 0.0)) + q++; + if (q == n) break; // we're done. It's diagonal. + KALDI_ASSERT(n - q >= 2); + MatrixIndexT npq = 2; // Value of n - p - q, where n - p - q must be + // unreduced. This is the size of "middle" band of elements. If q != n, + // we must have hit a nonzero off-diag element, so the size of this + // band must be at least two. + while (npq + q < n && (n-q-npq-1 < 0 || off_diag[n-q-npq-1] != 0.0)) + npq++; + MatrixIndexT p = n - q - npq; + { // Checks. + for (MatrixIndexT i = 0; i+1 < npq; i++) + KALDI_ASSERT(off_diag[p + i] != 0.0); + for (MatrixIndexT i = 0; i+1 < q; i++) + KALDI_ASSERT(off_diag[p + npq - 1 + i] == 0.0); + if (p > 1) // Something must have stopped npq from growing further.. + KALDI_ASSERT(off_diag[p-1] == 0.0); // so last off-diag elem in + // group of size p must be zero. + } + + if (Q != NULL) { + // Do one QR step on the middle part of Q only. + // Qpart will be a subset of the rows of Q. 
+ SubMatrix Qpart(*Q, p, npq, 0, Q->NumCols()); + QrStep(npq, diag + p, off_diag + p, &Qpart); + } else { + QrStep(npq, diag + p, off_diag + p, + static_cast*>(NULL)); + } + } + if (counter == max_iters) { + KALDI_WARN << "Failure to converge in QR algorithm. " + << "Exiting with partial output."; + } +} + + +/** + This is the symmetric QR algorithm, from Golub and Van Loan 3rd ed., Algorithm + 8.3.3. Q is transposed w.r.t. there, though. +*/ +template +void SpMatrix::Qr(MatrixBase *Q) { + KALDI_ASSERT(this->IsTridiagonal()); + // We envisage that Q would be square but we don't check for this, + // as there are situations where you might not want this. + KALDI_ASSERT(Q == NULL || Q->NumRows() == this->NumRows()); + // Note: the first couple of lines of the algorithm they give would be done + // outside of this function, by calling Tridiagonalize(). + + MatrixIndexT n = this->NumRows(); + Vector diag(n), off_diag(n-1); + for (MatrixIndexT i = 0; i < n; i++) { + diag(i) = (*this)(i, i); + if (i > 0) off_diag(i-1) = (*this)(i, i-1); + } + QrInternal(n, diag.Data(), off_diag.Data(), Q); + // Now set *this to the value represented by diag and off_diag. + this->SetZero(); + for (MatrixIndexT i = 0; i < n; i++) { + (*this)(i, i) = diag(i); + if (i > 0) (*this)(i, i-1) = off_diag(i-1); + } +} + +template +void SpMatrix::Eig(VectorBase *s, MatrixBase *P) const { + MatrixIndexT dim = this->NumRows(); + KALDI_ASSERT(s->Dim() == dim); + KALDI_ASSERT(P == NULL || (P->NumRows() == dim && P->NumCols() == dim)); + + SpMatrix A(*this); // Copy *this, since the tridiagonalization + // and QR decomposition are destructive. + // Note: for efficiency of memory access, the tridiagonalization + // algorithm makes the *rows* of P the eigenvectors, not the columns. + // We'll transpose P before we exit. + // Also note: P may be null if you don't want the eigenvectors. This + // will make this function more efficient. + + A.Tridiagonalize(P); // Tridiagonalizes. + A.Qr(P); // Diagonalizes. 
+ if(P) P->Transpose(); + s->CopyDiagFromPacked(A); +} + + +template +void SpMatrix::TopEigs(VectorBase *s, MatrixBase *P, + MatrixIndexT lanczos_dim) const { + const SpMatrix &S(*this); // call this "S" for easy notation. + MatrixIndexT eig_dim = s->Dim(); // Space of dim we want to retain. + if (lanczos_dim <= 0) + lanczos_dim = std::max(eig_dim + 50, eig_dim + eig_dim/2); + MatrixIndexT dim = this->NumRows(); + if (lanczos_dim >= dim) { + // There would be no speed advantage in using this method, so just + // use the regular approach. + Vector s_tmp(dim); + Matrix P_tmp(dim, dim); + this->Eig(&s_tmp, &P_tmp); + SortSvd(&s_tmp, &P_tmp); + s->CopyFromVec(s_tmp.Range(0, eig_dim)); + P->CopyFromMat(P_tmp.Range(0, dim, 0, eig_dim)); + return; + } + KALDI_ASSERT(eig_dim <= dim && eig_dim > 0); + KALDI_ASSERT(P->NumRows() == dim && P->NumCols() == eig_dim); // each column + // is one eigenvector. + + Matrix Q(lanczos_dim, dim); // The rows of Q will be the + // orthogonal vectors of the Krylov subspace. + + SpMatrix T(lanczos_dim); // This will be equal to Q S Q^T, + // i.e. *this projected into the Krylov subspace. Note: only the + // diagonal and off-diagonal fo T are nonzero, i.e. it's tridiagonal, + // but we don't have access to the low-level algorithms that work + // on that type of matrix (since we want to use ATLAS). So we just + // do normal SVD, on a full matrix; it won't typically dominate. + + Q.Row(0).SetRandn(); + Q.Row(0).Scale(1.0 / Q.Row(0).Norm(2)); + for (MatrixIndexT d = 0; d < lanczos_dim; d++) { + Vector r(dim); + r.AddSpVec(1.0, S, Q.Row(d), 0.0); + // r = S * q_d + MatrixIndexT counter = 0; + Real end_prod; + while (1) { // Normally we'll do this loop only once: + // we repeat to handle cases where r gets very much smaller + // and we want to orthogonalize again. + // We do "full orthogonalization" to preserve stability, + // even though this is usually a waste of time. 
+ Real start_prod = VecVec(r, r); + for (SignedMatrixIndexT e = d; e >= 0; e--) { // e must be signed! + SubVector q_e(Q, e); + Real prod = VecVec(r, q_e); + if (counter == 0 && static_cast(e) + 1 >= d) // Keep T tridiagonal, which + T(d, e) = prod; // mathematically speaking, it is. + r.AddVec(-prod, q_e); // Subtract component in q_e. + } + if (d+1 == lanczos_dim) break; + end_prod = VecVec(r, r); + if (end_prod <= 0.1 * start_prod) { + // also handles case where both are 0. + // We're not confident any more that it's completely + // orthogonal to the rest so we want to re-do. + if (end_prod == 0.0) + r.SetRandn(); // "Restarting". + counter++; + if (counter > 100) + KALDI_ERR << "Loop detected in Lanczos iteration."; + } else { + break; + } + } + if (d+1 != lanczos_dim) { + // OK, at this point we're satisfied that r is orthogonal + // to all previous rows. + KALDI_ASSERT(end_prod != 0.0); // should have looped. + r.Scale(1.0 / std::sqrt(end_prod)); // make it unit. + Q.Row(d+1).CopyFromVec(r); + } + } + + Matrix R(lanczos_dim, lanczos_dim); + R.SetUnit(); + T.Qr(&R); // Diagonalizes T. + Vector s_tmp(lanczos_dim); + s_tmp.CopyDiagFromSp(T); + + // Now T = R * diag(s_tmp) * R^T. + // The next call sorts the elements of s from greatest to least absolute value, + // and moves around the rows of R in the corresponding way. This picks out + // the largest (absolute) eigenvalues. + SortSvd(&s_tmp, static_cast*>(NULL), &R); + // Keep only the initial rows of R, those corresponding to greatest (absolute) + // eigenvalues. + SubMatrix Rsub(R, 0, eig_dim, 0, lanczos_dim); + SubVector s_sub(s_tmp, 0, eig_dim); + s->CopyFromVec(s_sub); + + // For working out what to do now, just assume the other eigenvalues were + // zero. This is just for purposes of knowing how to get the result, and + // not getting things wrongly transposed. + // We have T = Rsub^T * diag(s_sub) * Rsub. + // Now, T = Q S Q^T, with Q orthogonal, so S = Q^T T Q = Q^T Rsub^T * diag(s) * Rsub * Q. 
+ // The output is P and we want S = P * diag(s) * P^T, so we need P = Q^T Rsub^T. + P->AddMatMat(1.0, Q, kTrans, Rsub, kTrans, 0.0); +} + + +// Instantiate the templates for Eig and TopEig. +template +void SpMatrix::Eig(VectorBase*, MatrixBase*) const; +template +void SpMatrix::Eig(VectorBase*, MatrixBase*) const; + +template +void SpMatrix::TopEigs(VectorBase*, MatrixBase*, MatrixIndexT) const; +template +void SpMatrix::TopEigs(VectorBase*, MatrixBase*, MatrixIndexT) const; + +// Someone had a problem with the Intel compiler with -O3, with Qr not being +// defined for some strange reason (should automatically happen when +// we instantiate Eig and TopEigs), so we explicitly instantiate it here. +template +void SpMatrix::Qr(MatrixBase *Q); +template +void SpMatrix::Qr(MatrixBase *Q); + + + +} +// namespace kaldi diff --git a/speechx/speechx/kaldi/matrix/sp-matrix-inl.h b/speechx/speechx/kaldi/matrix/sp-matrix-inl.h new file mode 100644 index 00000000..15795923 --- /dev/null +++ b/speechx/speechx/kaldi/matrix/sp-matrix-inl.h @@ -0,0 +1,42 @@ +// matrix/sp-matrix-inl.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Haihua Xu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef KALDI_MATRIX_SP_MATRIX_INL_H_ +#define KALDI_MATRIX_SP_MATRIX_INL_H_ + +#include "matrix/tp-matrix.h" + +namespace kaldi { + +// All the lines in this file seem to be declaring template specializations. +// These tell the compiler that we'll implement the templated function +// separately for the different template arguments (float, double). + +template<> +double SolveQuadraticProblem(const SpMatrix &H, const VectorBase &g, + const SolverOptions &opts, VectorBase *x); + +template<> +float SolveQuadraticProblem(const SpMatrix &H, const VectorBase &g, + const SolverOptions &opts, VectorBase *x); + +} // namespace kaldi + + +#endif // KALDI_MATRIX_SP_MATRIX_INL_H_ diff --git a/speechx/speechx/kaldi/matrix/sp-matrix.cc b/speechx/speechx/kaldi/matrix/sp-matrix.cc new file mode 100644 index 00000000..224ef39f --- /dev/null +++ b/speechx/speechx/kaldi/matrix/sp-matrix.cc @@ -0,0 +1,1216 @@ +// matrix/sp-matrix.cc + +// Copyright 2009-2011 Lukas Burget; Ondrej Glembek; Microsoft Corporation +// Saarland University; Petr Schwarz; Yanmin Qian; +// Haihua Xu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include + +#include "matrix/sp-matrix.h" +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" +#include "matrix/matrix-functions.h" +#include "matrix/cblas-wrappers.h" + +namespace kaldi { + +// **************************************************************************** +// Returns the log-determinant if +ve definite, else KALDI_ERR. +// **************************************************************************** +template +Real SpMatrix::LogPosDefDet() const { + TpMatrix chol(this->NumRows()); + double det = 0.0; + double diag; + chol.Cholesky(*this); // Will throw exception if not +ve definite! + + for (MatrixIndexT i = 0; i < this->NumRows(); i++) { + diag = static_cast(chol(i, i)); + det += kaldi::Log(diag); + } + return static_cast(2*det); +} + + +template +void SpMatrix::Swap(SpMatrix *other) { + std::swap(this->data_, other->data_); + std::swap(this->num_rows_, other->num_rows_); +} + +template +void SpMatrix::SymPosSemiDefEig(VectorBase *s, + MatrixBase *P, + Real tolerance) const { + Eig(s, P); + Real max = s->Max(), min = s->Min(); + KALDI_ASSERT(-min <= tolerance * max); + s->ApplyFloor(0.0); +} + +template +Real SpMatrix::MaxAbsEig() const { + Vector s(this->NumRows()); + this->Eig(&s, static_cast*>(NULL)); + return std::max(s.Max(), -s.Min()); +} + +// returns true if positive definite--uses cholesky. +template +bool SpMatrix::IsPosDef() const { + MatrixIndexT D = (*this).NumRows(); + KALDI_ASSERT(D > 0); + try { + TpMatrix C(D); + C.Cholesky(*this); + for (MatrixIndexT r = 0; r < D; r++) + if (C(r, r) == 0.0) return false; + return true; + } + catch(...) { // not positive semidefinite. + return false; + } +} + +template +void SpMatrix::ApplyPow(Real power) { + if (power == 1) return; // can do nothing. + MatrixIndexT D = this->NumRows(); + KALDI_ASSERT(D > 0); + Matrix U(D, D); + Vector l(D); + (*this).SymPosSemiDefEig(&l, &U); + + Vector l_copy(l); + try { + l.ApplyPow(power * 0.5); + } + catch(...) 
{ + KALDI_ERR << "Error taking power " << (power * 0.5) << " of vector " + << l_copy; + } + U.MulColsVec(l); + (*this).AddMat2(1.0, U, kNoTrans, 0.0); +} + +template +void SpMatrix::CopyFromMat(const MatrixBase &M, + SpCopyType copy_type) { + KALDI_ASSERT(this->NumRows() == M.NumRows() && M.NumRows() == M.NumCols()); + MatrixIndexT D = this->NumRows(); + + switch (copy_type) { + case kTakeMeanAndCheck: + { + Real good_sum = 0.0, bad_sum = 0.0; + for (MatrixIndexT i = 0; i < D; i++) { + for (MatrixIndexT j = 0; j < i; j++) { + Real a = M(i, j), b = M(j, i), avg = 0.5*(a+b), diff = 0.5*(a-b); + (*this)(i, j) = avg; + good_sum += std::abs(avg); + bad_sum += std::abs(diff); + } + good_sum += std::abs(M(i, i)); + (*this)(i, i) = M(i, i); + } + if (bad_sum > 0.01 * good_sum) { + KALDI_ERR << "SpMatrix::Copy(), source matrix is not symmetric: " + << bad_sum << ">" << good_sum; + } + break; + } + case kTakeMean: + { + for (MatrixIndexT i = 0; i < D; i++) { + for (MatrixIndexT j = 0; j < i; j++) { + (*this)(i, j) = 0.5*(M(i, j) + M(j, i)); + } + (*this)(i, i) = M(i, i); + } + break; + } + case kTakeLower: + { // making this one a bit more efficient. 
+ const Real *src = M.Data(); + Real *dest = this->data_; + MatrixIndexT stride = M.Stride(); + for (MatrixIndexT i = 0; i < D; i++) { + for (MatrixIndexT j = 0; j <= i; j++) + dest[j] = src[j]; + dest += i + 1; + src += stride; + } + } + break; + case kTakeUpper: + for (MatrixIndexT i = 0; i < D; i++) + for (MatrixIndexT j = 0; j <= i; j++) + (*this)(i, j) = M(j, i); + break; + default: + KALDI_ASSERT("Invalid argument to SpMatrix::CopyFromMat"); + } +} + +template +Real SpMatrix::Trace() const { + const Real *data = this->data_; + MatrixIndexT num_rows = this->num_rows_; + Real ans = 0.0; + for (int32 i = 1; i <= num_rows; i++, data += i) + ans += *data; + return ans; +} + +// diagonal update, this <-- this + diag(v) +template +template +void SpMatrix::AddDiagVec(const Real alpha, const VectorBase &v) { + int32 num_rows = this->num_rows_; + KALDI_ASSERT(num_rows == v.Dim() && num_rows > 0); + const OtherReal *src = v.Data(); + Real *dst = this->data_; + if (alpha == 1.0) + for (int32 i = 1; i <= num_rows; i++, src++, dst += i) + *dst += *src; + else + for (int32 i = 1; i <= num_rows; i++, src++, dst += i) + *dst += alpha * *src; +} + +// instantiate the template above. 
+template +void SpMatrix::AddDiagVec(const float alpha, + const VectorBase &v); + +template +void SpMatrix::AddDiagVec(const double alpha, + const VectorBase &v); + +template +void SpMatrix::AddDiagVec(const float alpha, + const VectorBase &v); + +template +void SpMatrix::AddDiagVec(const double alpha, + const VectorBase &v); + +template<> +template<> +void SpMatrix::AddVec2(const double alpha, const VectorBase &v); + +#ifndef HAVE_ATLAS +template +void SpMatrix::Invert(Real *logdet, Real *det_sign, bool need_inverse) { + // these are CLAPACK types + KaldiBlasInt result; + KaldiBlasInt rows = static_cast(this->num_rows_); + KaldiBlasInt* p_ipiv = new KaldiBlasInt[rows]; + Real *p_work; // workspace for the lapack function + void *temp; + if ((p_work = static_cast( + KALDI_MEMALIGN(16, sizeof(Real) * rows, &temp))) == NULL) { + delete[] p_ipiv; + throw std::bad_alloc(); + } +#ifdef HAVE_OPENBLAS + memset(p_work, 0, sizeof(Real) * rows); // gets rid of a probably + // spurious Valgrind warning about jumps depending upon uninitialized values. +#endif + + + // NOTE: Even though "U" is for upper, lapack assumes column-wise storage + // of the data. We have a row-wise storage, therefore, we need to "invert" + clapack_Xsptrf(&rows, this->data_, p_ipiv, &result); + + + KALDI_ASSERT(result >= 0 && "Call to CLAPACK ssptrf_ called with wrong arguments"); + + if (result > 0) { // Singular... + if (det_sign) *det_sign = 0; + if (logdet) *logdet = -std::numeric_limits::infinity(); + if (need_inverse) KALDI_ERR << "CLAPACK stptrf_ : factorization failed"; + } else { // Not singular.. compute log-determinant if needed. + if (logdet != NULL || det_sign != NULL) { + Real prod = 1.0, log_prod = 0.0; + int sign = 1; + for (int i = 0; i < (int)this->num_rows_; i++) { + if (p_ipiv[i] > 0) { // not a 2x2 block... + // if (p_ipiv[i] != i+1) sign *= -1; // row swap. + Real diag = (*this)(i, i); + prod *= diag; + } else { // negative: 2x2 block. [we are in first of the two]. 
+ i++; // skip over the first of the pair. + // each 2x2 block... + Real diag1 = (*this)(i, i), diag2 = (*this)(i-1, i-1), + offdiag = (*this)(i, i-1); + Real thisdet = diag1*diag2 - offdiag*offdiag; + // thisdet == determinant of 2x2 block. + // The following line is more complex than it looks: there are 2 offsets of + // 1 that cancel. + prod *= thisdet; + } + if (i == (int)(this->num_rows_-1) || fabs(prod) < 1.0e-10 || fabs(prod) > 1.0e+10) { + if (prod < 0) { prod = -prod; sign *= -1; } + log_prod += kaldi::Log(std::abs(prod)); + prod = 1.0; + } + } + if (logdet != NULL) *logdet = log_prod; + if (det_sign != NULL) *det_sign = sign; + } + } + if (!need_inverse) { + delete [] p_ipiv; + KALDI_MEMALIGN_FREE(p_work); + return; // Don't need what is computed next. + } + // NOTE: Even though "U" is for upper, lapack assumes column-wise storage + // of the data. We have a row-wise storage, therefore, we need to "invert" + clapack_Xsptri(&rows, this->data_, p_ipiv, p_work, &result); + + KALDI_ASSERT(result >=0 && + "Call to CLAPACK ssptri_ called with wrong arguments"); + + if (result != 0) { + KALDI_ERR << "CLAPACK ssptrf_ : Matrix is singular"; + } + + delete [] p_ipiv; + KALDI_MEMALIGN_FREE(p_work); +} +#else +// in the ATLAS case, these are not implemented using a library and we back off to something else. +template +void SpMatrix::Invert(Real *logdet, Real *det_sign, bool need_inverse) { + Matrix M(this->NumRows(), this->NumCols()); + M.CopyFromSp(*this); + M.Invert(logdet, det_sign, need_inverse); + if (need_inverse) + for (MatrixIndexT i = 0; i < this->NumRows(); i++) + for (MatrixIndexT j = 0; j <= i; j++) + (*this)(i, j) = M(i, j); +} +#endif + +template +void SpMatrix::InvertDouble(Real *logdet, Real *det_sign, + bool inverse_needed) { + SpMatrix dmat(*this); + double logdet_tmp, det_sign_tmp; + dmat.Invert(logdet ? &logdet_tmp : NULL, + det_sign ? 
&det_sign_tmp : NULL, + inverse_needed); + if (logdet) *logdet = logdet_tmp; + if (det_sign) *det_sign = det_sign_tmp; + (*this).CopyFromSp(dmat); +} + + + +double TraceSpSp(const SpMatrix &A, const SpMatrix &B) { + KALDI_ASSERT(A.NumRows() == B.NumRows()); + const double *Aptr = A.Data(); + const double *Bptr = B.Data(); + MatrixIndexT R = A.NumRows(); + MatrixIndexT RR = (R * (R + 1)) / 2; + double all_twice = 2.0 * cblas_Xdot(RR, Aptr, 1, Bptr, 1); + // "all_twice" contains twice the vector-wise dot-product... this is + // what we want except the diagonal elements are represented + // twice. + double diag_once = 0.0; + for (MatrixIndexT row_plus_two = 2; row_plus_two <= R + 1; row_plus_two++) { + diag_once += *Aptr * *Bptr; + Aptr += row_plus_two; + Bptr += row_plus_two; + } + return all_twice - diag_once; +} + + +float TraceSpSp(const SpMatrix &A, const SpMatrix &B) { + KALDI_ASSERT(A.NumRows() == B.NumRows()); + const float *Aptr = A.Data(); + const float *Bptr = B.Data(); + MatrixIndexT R = A.NumRows(); + MatrixIndexT RR = (R * (R + 1)) / 2; + float all_twice = 2.0 * cblas_Xdot(RR, Aptr, 1, Bptr, 1); + // "all_twice" contains twice the vector-wise dot-product... this is + // what we want except the diagonal elements are represented + // twice. + float diag_once = 0.0; + for (MatrixIndexT row_plus_two = 2; row_plus_two <= R + 1; row_plus_two++) { + diag_once += *Aptr * *Bptr; + Aptr += row_plus_two; + Bptr += row_plus_two; + } + return all_twice - diag_once; +} + + +template +Real TraceSpSp(const SpMatrix &A, const SpMatrix &B) { + KALDI_ASSERT(A.NumRows() == B.NumRows()); + Real ans = 0.0; + const Real *Aptr = A.Data(); + const OtherReal *Bptr = B.Data(); + MatrixIndexT row, col, R = A.NumRows(); + for (row = 0; row < R; row++) { + for (col = 0; col < row; col++) + ans += 2.0 * *(Aptr++) * *(Bptr++); + ans += *(Aptr++) * *(Bptr++); // Diagonal. 
+ } + return ans; +} + +template +float TraceSpSp(const SpMatrix &A, const SpMatrix &B); + +template +double TraceSpSp(const SpMatrix &A, const SpMatrix &B); + + +template +Real TraceSpMat(const SpMatrix &A, const MatrixBase &B) { + KALDI_ASSERT(A.NumRows() == B.NumRows() && A.NumCols() == B.NumCols() && + "KALDI_ERR: TraceSpMat: arguments have mismatched dimension"); + MatrixIndexT R = A.NumRows(); + Real ans = (Real)0.0; + const Real *Aptr = A.Data(), *Bptr = B.Data(); + MatrixIndexT bStride = B.Stride(); + for (MatrixIndexT r = 0;r < R;r++) { + for (MatrixIndexT c = 0;c < r;c++) { + // ans += A(r, c) * (B(r, c) + B(c, r)); + ans += *(Aptr++) * (Bptr[r*bStride + c] + Bptr[c*bStride + r]); + } + // ans += A(r, r) * B(r, r); + ans += *(Aptr++) * Bptr[r*bStride + r]; + } + return ans; +} + +template +float TraceSpMat(const SpMatrix &A, const MatrixBase &B); + +template +double TraceSpMat(const SpMatrix &A, const MatrixBase &B); + + +template +Real TraceMatSpMat(const MatrixBase &A, MatrixTransposeType transA, + const SpMatrix &B, const MatrixBase &C, + MatrixTransposeType transC) { + KALDI_ASSERT((transA == kTrans?A.NumCols():A.NumRows()) == + (transC == kTrans?C.NumRows():C.NumCols()) && + (transA == kTrans?A.NumRows():A.NumCols()) == B.NumRows() && + (transC == kTrans?C.NumCols():C.NumRows()) == B.NumRows() && + "TraceMatSpMat: arguments have wrong dimension."); + Matrix tmp(B.NumRows(), B.NumRows()); + tmp.AddMatMat(1.0, C, transC, A, transA, 0.0); // tmp = C * A. 
+ return TraceSpMat(B, tmp); +} + +template +float TraceMatSpMat(const MatrixBase &A, MatrixTransposeType transA, + const SpMatrix &B, const MatrixBase &C, + MatrixTransposeType transC); +template +double TraceMatSpMat(const MatrixBase &A, MatrixTransposeType transA, + const SpMatrix &B, const MatrixBase &C, + MatrixTransposeType transC); + +template +Real TraceMatSpMatSp(const MatrixBase &A, MatrixTransposeType transA, + const SpMatrix &B, const MatrixBase &C, + MatrixTransposeType transC, const SpMatrix &D) { + KALDI_ASSERT((transA == kTrans ?A.NumCols():A.NumRows() == D.NumCols()) && + (transA == kTrans ? A.NumRows():A.NumCols() == B.NumRows()) && + (transC == kTrans ? A.NumCols():A.NumRows() == B.NumCols()) && + (transC == kTrans ? A.NumRows():A.NumCols() == D.NumRows()) && + "KALDI_ERR: TraceMatSpMatSp: arguments have mismatched dimension."); + // Could perhaps optimize this more depending on dimensions of quantities. + Matrix tmpAB(transA == kTrans ? A.NumCols():A.NumRows(), B.NumCols()); + tmpAB.AddMatSp(1.0, A, transA, B, 0.0); + Matrix tmpCD(transC == kTrans ? 
C.NumCols():C.NumRows(), D.NumCols()); + tmpCD.AddMatSp(1.0, C, transC, D, 0.0); + return TraceMatMat(tmpAB, tmpCD, kNoTrans); +} + +template +float TraceMatSpMatSp(const MatrixBase &A, MatrixTransposeType transA, + const SpMatrix &B, const MatrixBase &C, + MatrixTransposeType transC, const SpMatrix &D); +template +double TraceMatSpMatSp(const MatrixBase &A, MatrixTransposeType transA, + const SpMatrix &B, const MatrixBase &C, + MatrixTransposeType transC, const SpMatrix &D); + + +template +bool SpMatrix::IsDiagonal(Real cutoff) const { + MatrixIndexT R = this->NumRows(); + Real bad_sum = 0.0, good_sum = 0.0; + for (MatrixIndexT i = 0; i < R; i++) { + for (MatrixIndexT j = 0; j <= i; j++) { + if (i == j) + good_sum += std::abs((*this)(i, j)); + else + bad_sum += std::abs((*this)(i, j)); + } + } + return (!(bad_sum > good_sum * cutoff)); +} + +template +bool SpMatrix::IsUnit(Real cutoff) const { + MatrixIndexT R = this->NumRows(); + Real max = 0.0; // max error + for (MatrixIndexT i = 0; i < R; i++) + for (MatrixIndexT j = 0; j <= i; j++) + max = std::max(max, static_cast(std::abs((*this)(i, j) - + (i == j ? 
1.0 : 0.0)))); + return (max <= cutoff); +} + +template +bool SpMatrix::IsTridiagonal(Real cutoff) const { + MatrixIndexT R = this->NumRows(); + Real max_abs_2diag = 0.0, max_abs_offdiag = 0.0; + for (MatrixIndexT i = 0; i < R; i++) + for (MatrixIndexT j = 0; j <= i; j++) { + if (j+1 < i) + max_abs_offdiag = std::max(max_abs_offdiag, + std::abs((*this)(i, j))); + else + max_abs_2diag = std::max(max_abs_2diag, + std::abs((*this)(i, j))); + } + return (max_abs_offdiag <= cutoff * max_abs_2diag); +} + +template +bool SpMatrix::IsZero(Real cutoff) const { + if (this->num_rows_ == 0) return true; + return (this->Max() <= cutoff && this->Min() >= -cutoff); +} + +template +Real SpMatrix::FrobeniusNorm() const { + Real sum = 0.0; + MatrixIndexT R = this->NumRows(); + for (MatrixIndexT i = 0; i < R; i++) { + for (MatrixIndexT j = 0; j < i; j++) + sum += (*this)(i, j) * (*this)(i, j) * 2; + sum += (*this)(i, i) * (*this)(i, i); + } + return std::sqrt(sum); +} + +template +bool SpMatrix::ApproxEqual(const SpMatrix &other, float tol) const { + if (this->NumRows() != other.NumRows()) + KALDI_ERR << "SpMatrix::AproxEqual, size mismatch, " + << this->NumRows() << " vs. " << other.NumRows(); + SpMatrix tmp(*this); + tmp.AddSp(-1.0, other); + return (tmp.FrobeniusNorm() <= tol * std::max(this->FrobeniusNorm(), other.FrobeniusNorm())); +} + +// function Floor: A = Floor(B, alpha * C) ... see tutorial document. +template +int SpMatrix::ApplyFloor(const SpMatrix &C, Real alpha, + bool verbose) { + MatrixIndexT dim = this->NumRows(); + int nfloored = 0; + KALDI_ASSERT(C.NumRows() == dim); + KALDI_ASSERT(alpha > 0); + TpMatrix L(dim); + L.Cholesky(C); + L.Scale(std::sqrt(alpha)); // equivalent to scaling C by alpha. 
+ TpMatrix LInv(L); + LInv.Invert(); + + SpMatrix D(dim); + { // D = L^{-1} * (*this) * L^{-T} + Matrix LInvFull(LInv); + D.AddMat2Sp(1.0, LInvFull, kNoTrans, (*this), 0.0); + } + + Vector l(dim); + Matrix U(dim, dim); + + D.Eig(&l, &U); + + if (verbose) { + KALDI_LOG << "ApplyFloor: flooring following diagonal to 1: " << l; + } + for (MatrixIndexT i = 0; i < l.Dim(); i++) { + if (l(i) < 1.0) { + nfloored++; + l(i) = 1.0; + } + } + l.ApplyPow(0.5); + U.MulColsVec(l); + D.AddMat2(1.0, U, kNoTrans, 0.0); + { // D' := U * diag(l') * U^T ... l'=floor(l, 1) + Matrix LFull(L); + (*this).AddMat2Sp(1.0, LFull, kNoTrans, D, 0.0); // A := L * D' * L^T + } + return nfloored; +} + +template +Real SpMatrix::LogDet(Real *det_sign) const { + Real log_det; + SpMatrix tmp(*this); + // false== output not needed (saves some computation). + tmp.Invert(&log_det, det_sign, false); + return log_det; +} + + +template +int SpMatrix::ApplyFloor(Real floor) { + MatrixIndexT Dim = this->NumRows(); + int nfloored = 0; + Vector s(Dim); + Matrix P(Dim, Dim); + (*this).Eig(&s, &P); + for (MatrixIndexT i = 0; i < Dim; i++) { + if (s(i) < floor) { + nfloored++; + s(i) = floor; + } + } + (*this).AddMat2Vec(1.0, P, kNoTrans, s, 0.0); + return nfloored; +} + +template +MatrixIndexT SpMatrix::LimitCond(Real maxCond, bool invert) { // e.g. maxCond = 1.0e+05. + MatrixIndexT Dim = this->NumRows(); + Vector s(Dim); + Matrix P(Dim, Dim); + (*this).SymPosSemiDefEig(&s, &P); + KALDI_ASSERT(maxCond > 1); + Real floor = s.Max() / maxCond; + if (floor < 0) floor = 0; + if (floor < 1.0e-40) { + KALDI_WARN << "LimitCond: limiting " << floor << " to 1.0e-40"; + floor = 1.0e-40; + } + MatrixIndexT nfloored = 0; + for (MatrixIndexT i = 0; i < Dim; i++) { + if (s(i) <= floor) nfloored++; + if (invert) + s(i) = 1.0 / std::sqrt(std::max(s(i), floor)); + else + s(i) = std::sqrt(std::max(s(i), floor)); + } + P.MulColsVec(s); + (*this).AddMat2(1.0, P, kNoTrans, 0.0); // (*this) = P*P^T. ... (*this) = P * floor(s) * P^T ... 
if P was original P. + return nfloored; +} + +void SolverOptions::Check() const { + KALDI_ASSERT(K>10 && eps<1.0e-10); +} + +template<> double SolveQuadraticProblem(const SpMatrix &H, + const VectorBase &g, + const SolverOptions &opts, + VectorBase *x) { + KALDI_ASSERT(H.NumRows() == g.Dim() && g.Dim() == x->Dim() && x->Dim() != 0); + opts.Check(); + MatrixIndexT dim = x->Dim(); + if (H.IsZero(0.0)) { + KALDI_WARN << "Zero quadratic term in quadratic vector problem for " + << opts.name << ": leaving it unchanged."; + return 0.0; + } + if (opts.diagonal_precondition) { + // We can re-cast the problem with a diagonal preconditioner to + // make H better-conditioned. + Vector H_diag(dim); + H_diag.CopyDiagFromSp(H); + H_diag.ApplyFloor(std::numeric_limits::min() * 1.0E+3); + Vector H_diag_sqrt(H_diag); + H_diag_sqrt.ApplyPow(0.5); + Vector H_diag_inv_sqrt(H_diag_sqrt); + H_diag_inv_sqrt.InvertElements(); + Vector x_scaled(*x); + x_scaled.MulElements(H_diag_sqrt); + Vector g_scaled(g); + g_scaled.MulElements(H_diag_inv_sqrt); + SpMatrix H_scaled(dim); + H_scaled.AddVec2Sp(1.0, H_diag_inv_sqrt, H, 0.0); + double ans; + SolverOptions new_opts(opts); + new_opts.diagonal_precondition = false; + ans = SolveQuadraticProblem(H_scaled, g_scaled, new_opts, &x_scaled); + x->CopyFromVec(x_scaled); + x->MulElements(H_diag_inv_sqrt); + return ans; + } + Vector gbar(g); + if (opts.optimize_delta) gbar.AddSpVec(-1.0, H, *x, 1.0); // gbar = g - H x + Matrix U(dim, dim); + Vector l(dim); + H.SymPosSemiDefEig(&l, &U); // does svd H = U L V^T and checks that H == U L U^T to within a tolerance. + // floor l. + double f = std::max(static_cast(opts.eps), l.Max() / opts.K); + MatrixIndexT nfloored = 0; + for (MatrixIndexT i = 0; i < dim; i++) { // floor l. + if (l(i) < f) { + nfloored++; + l(i) = f; + } + } + if (nfloored != 0 && opts.print_debug_output) { + KALDI_LOG << "Solving quadratic problem for " << opts.name + << ": floored " << nfloored<< " eigenvalues. 
"; + } + Vector tmp(dim); + tmp.AddMatVec(1.0, U, kTrans, gbar, 0.0); // tmp = U^T \bar{g} + tmp.DivElements(l); // divide each element of tmp by l: tmp = \tilde{L}^{-1} U^T \bar{g} + Vector delta(dim); + delta.AddMatVec(1.0, U, kNoTrans, tmp, 0.0); // delta = U tmp = U \tilde{L}^{-1} U^T \bar{g} + Vector &xhat(tmp); + xhat.CopyFromVec(delta); + if (opts.optimize_delta) xhat.AddVec(1.0, *x); // xhat = x + delta. + double auxf_before = VecVec(g, *x) - 0.5 * VecSpVec(*x, H, *x), + auxf_after = VecVec(g, xhat) - 0.5 * VecSpVec(xhat, H, xhat); + if (auxf_after < auxf_before) { // Reject change. + if (auxf_after < auxf_before - 1.0e-10 && opts.print_debug_output) + KALDI_WARN << "Optimizing vector auxiliary function for " + << opts.name<< ": auxf decreased " << auxf_before + << " to " << auxf_after << ", change is " + << (auxf_after-auxf_before); + return 0.0; + } else { + x->CopyFromVec(xhat); + return auxf_after - auxf_before; + } +} + +template<> float SolveQuadraticProblem(const SpMatrix &H, + const VectorBase &g, + const SolverOptions &opts, + VectorBase *x) { + KALDI_ASSERT(H.NumRows() == g.Dim() && g.Dim() == x->Dim() && x->Dim() != 0); + SpMatrix Hd(H); + Vector gd(g); + Vector xd(*x); + float ans = static_cast(SolveQuadraticProblem(Hd, gd, opts, &xd)); + x->CopyFromVec(xd); + return ans; +} + +// Maximizes the auxiliary function Q(x) = tr(M^T SigmaInv Y) - 0.5 tr(SigmaInv M Q M^T). +// Like a numerically stable version of M := Y Q^{-1}. 
+template +Real +SolveQuadraticMatrixProblem(const SpMatrix &Q, + const MatrixBase &Y, + const SpMatrix &SigmaInv, + const SolverOptions &opts, + MatrixBase *M) { + KALDI_ASSERT(Q.NumRows() == M->NumCols() && + SigmaInv.NumRows() == M->NumRows() && Y.NumRows() == M->NumRows() + && Y.NumCols() == M->NumCols() && M->NumCols() != 0); + opts.Check(); + MatrixIndexT rows = M->NumRows(), cols = M->NumCols(); + if (Q.IsZero(0.0)) { + KALDI_WARN << "Zero quadratic term in quadratic matrix problem for " + << opts.name << ": leaving it unchanged."; + return 0.0; + } + + if (opts.diagonal_precondition) { + // We can re-cast the problem with a diagonal preconditioner in the space + // of Q (columns of M). Helps to improve the condition of Q. + Vector Q_diag(cols); + Q_diag.CopyDiagFromSp(Q); + Q_diag.ApplyFloor(std::numeric_limits::min() * 1.0E+3); + Vector Q_diag_sqrt(Q_diag); + Q_diag_sqrt.ApplyPow(0.5); + Vector Q_diag_inv_sqrt(Q_diag_sqrt); + Q_diag_inv_sqrt.InvertElements(); + Matrix M_scaled(*M); + M_scaled.MulColsVec(Q_diag_sqrt); + Matrix Y_scaled(Y); + Y_scaled.MulColsVec(Q_diag_inv_sqrt); + SpMatrix Q_scaled(cols); + Q_scaled.AddVec2Sp(1.0, Q_diag_inv_sqrt, Q, 0.0); + Real ans; + SolverOptions new_opts(opts); + new_opts.diagonal_precondition = false; + ans = SolveQuadraticMatrixProblem(Q_scaled, Y_scaled, SigmaInv, + new_opts, &M_scaled); + M->CopyFromMat(M_scaled); + M->MulColsVec(Q_diag_inv_sqrt); + return ans; + } + + Matrix Ybar(Y); + if (opts.optimize_delta) { + Matrix Qfull(Q); + Ybar.AddMatMat(-1.0, *M, kNoTrans, Qfull, kNoTrans, 1.0); + } // Ybar = Y - M Q. + Matrix U(cols, cols); + Vector l(cols); + Q.SymPosSemiDefEig(&l, &U); // does svd Q = U L V^T and checks that Q == U L U^T to within a tolerance. + // floor l. + Real f = std::max(static_cast(opts.eps), l.Max() / opts.K); + MatrixIndexT nfloored = 0; + for (MatrixIndexT i = 0; i < cols; i++) { // floor l. 
+ if (l(i) < f) { nfloored++; l(i) = f; } + } + if (nfloored != 0 && opts.print_debug_output) + KALDI_LOG << "Solving matrix problem for " << opts.name + << ": floored " << nfloored << " eigenvalues. "; + Matrix tmpDelta(rows, cols); + tmpDelta.AddMatMat(1.0, Ybar, kNoTrans, U, kNoTrans, 0.0); // tmpDelta = Ybar * U. + l.InvertElements(); KALDI_ASSERT(1.0/l.Max() != 0); // check not infinite. eps should take care of this. + tmpDelta.MulColsVec(l); // tmpDelta = Ybar * U * \tilde{L}^{-1} + + Matrix Delta(rows, cols); + Delta.AddMatMat(1.0, tmpDelta, kNoTrans, U, kTrans, 0.0); // Delta = Ybar * U * \tilde{L}^{-1} * U^T + + Real auxf_before, auxf_after; + SpMatrix MQM(rows); + Matrix &SigmaInvY(tmpDelta); + { Matrix SigmaInvFull(SigmaInv); SigmaInvY.AddMatMat(1.0, SigmaInvFull, kNoTrans, Y, kNoTrans, 0.0); } + { // get auxf_before. Q(x) = tr(M^T SigmaInv Y) - 0.5 tr(SigmaInv M Q M^T). + MQM.AddMat2Sp(1.0, *M, kNoTrans, Q, 0.0); + auxf_before = TraceMatMat(*M, SigmaInvY, kaldi::kTrans) - 0.5*TraceSpSp(SigmaInv, MQM); + } + + Matrix Mhat(Delta); + if (opts.optimize_delta) Mhat.AddMat(1.0, *M); // Mhat = Delta + M. + + { // get auxf_after. 
+ MQM.AddMat2Sp(1.0, Mhat, kNoTrans, Q, 0.0); + auxf_after = TraceMatMat(Mhat, SigmaInvY, kaldi::kTrans) - 0.5*TraceSpSp(SigmaInv, MQM); + } + + if (auxf_after < auxf_before) { + if (auxf_after < auxf_before - 1.0e-10) + KALDI_WARN << "Optimizing matrix auxiliary function for " + << opts.name << ", auxf decreased " + << auxf_before << " to " << auxf_after << ", change is " + << (auxf_after-auxf_before); + return 0.0; + } else { + M->CopyFromMat(Mhat); + return auxf_after - auxf_before; + } +} + +template +Real SolveDoubleQuadraticMatrixProblem(const MatrixBase &G, + const SpMatrix &P1, + const SpMatrix &P2, + const SpMatrix &Q1, + const SpMatrix &Q2, + const SolverOptions &opts, + MatrixBase *M) { + KALDI_ASSERT(Q1.NumRows() == M->NumCols() && P1.NumRows() == M->NumRows() && + G.NumRows() == M->NumRows() && G.NumCols() == M->NumCols() && + M->NumCols() != 0 && Q2.NumRows() == M->NumCols() && + P2.NumRows() == M->NumRows()); + MatrixIndexT rows = M->NumRows(), cols = M->NumCols(); + // The following check should not fail as we stipulate P1, P2 and one of Q1 + // or Q2 must be +ve def and other Q1 or Q2 must be +ve semidef. + TpMatrix LInv(rows); + LInv.Cholesky(P1); + LInv.Invert(); // Will throw exception if fails. + SpMatrix S(rows); + Matrix LInvFull(LInv); + S.AddMat2Sp(1.0, LInvFull, kNoTrans, P2, 0.0); // S := L^{-1} P_2 L^{-T} + Matrix U(rows, rows); + Vector d(rows); + S.SymPosSemiDefEig(&d, &U); + Matrix T(rows, rows); + T.AddMatMat(1.0, U, kTrans, LInvFull, kNoTrans, 0.0); // T := U^T * L^{-1} + +#ifdef KALDI_PARANOID // checking mainly for errors in the code or math. 
+ { + SpMatrix P1Trans(rows); + P1Trans.AddMat2Sp(1.0, T, kNoTrans, P1, 0.0); + KALDI_ASSERT(P1Trans.IsUnit(0.01)); + } + { + SpMatrix P2Trans(rows); + P2Trans.AddMat2Sp(1.0, T, kNoTrans, P2, 0.0); + KALDI_ASSERT(P2Trans.IsDiagonal(0.01)); + } +#endif + + Matrix TInv(T); + TInv.Invert(); + Matrix Gdash(rows, cols); + Gdash.AddMatMat(1.0, T, kNoTrans, G, kNoTrans, 0.0); // G' = T G + Matrix MdashOld(rows, cols); + MdashOld.AddMatMat(1.0, TInv, kTrans, *M, kNoTrans, 0.0); // M' = T^{-T} M + Matrix MdashNew(MdashOld); + Real objf_impr = 0.0; + for (MatrixIndexT n = 0; n < rows; n++) { + SpMatrix Qsum(Q1); + Qsum.AddSp(d(n), Q2); + SubVector mdash_n = MdashNew.Row(n); + SubVector gdash_n = Gdash.Row(n); + + Matrix QsumInv(Qsum); + try { + QsumInv.Invert(); + Real old_objf = VecVec(mdash_n, gdash_n) + - 0.5 * VecSpVec(mdash_n, Qsum, mdash_n); + mdash_n.AddMatVec(1.0, QsumInv, kNoTrans, gdash_n, 0.0); // m'_n := g'_n * (Q_1 + d_n Q_2)^{-1} + Real new_objf = VecVec(mdash_n, gdash_n) + - 0.5 * VecSpVec(mdash_n, Qsum, mdash_n); + if (new_objf < old_objf) { + if (new_objf < old_objf - 1.0e-05) { + KALDI_WARN << "In double quadratic matrix problem: objective " + "function decreasing during optimization of " << opts.name + << ", " << old_objf << "->" << new_objf << ", change is " + << (new_objf - old_objf); + KALDI_ERR << "Auxiliary function decreasing."; // Will be caught. + } else { // Reset to old value, didn't improve (very close to optimum). + MdashNew.Row(n).CopyFromVec(MdashOld.Row(n)); + } + } + objf_impr += new_objf - old_objf; + } + catch (...) { + KALDI_WARN << "Matrix inversion or optimization failed during double " + "quadratic problem, solving for" << opts.name + << ": trying more stable approach."; + objf_impr += SolveQuadraticProblem(Qsum, gdash_n, opts, &mdash_n); + } + } + M->AddMatMat(1.0, T, kTrans, MdashNew, kNoTrans, 0.0); // M := T^T M'. 
+ return objf_impr; +} + +// rank-one update, this <-- this + alpha V V' +template<> +template<> +void SpMatrix::AddVec2(const float alpha, const VectorBase &v) { + KALDI_ASSERT(v.Dim() == this->NumRows()); + cblas_Xspr(v.Dim(), alpha, v.Data(), 1, + this->data_); +} + +template +void SpMatrix::AddVec2Sp(const Real alpha, const VectorBase &v, + const SpMatrix &S, const Real beta) { + KALDI_ASSERT(v.Dim() == this->NumRows() && S.NumRows() == this->NumRows()); + const Real *Sdata = S.Data(); + const Real *vdata = v.Data(); + Real *data = this->data_; + MatrixIndexT dim = this->num_rows_; + for (MatrixIndexT r = 0; r < dim; r++) + for (MatrixIndexT c = 0; c <= r; c++, Sdata++, data++) + *data = beta * *data + alpha * vdata[r] * vdata[c] * *Sdata; +} + + +// rank-one update, this <-- this + alpha V V' +template<> +template<> +void SpMatrix::AddVec2(const double alpha, const VectorBase &v) { + KALDI_ASSERT(v.Dim() == num_rows_); + cblas_Xspr(v.Dim(), alpha, v.Data(), 1, data_); +} + + +template +template +void SpMatrix::AddVec2(const Real alpha, const VectorBase &v) { + KALDI_ASSERT(v.Dim() == this->NumRows()); + Real *data = this->data_; + const OtherReal *v_data = v.Data(); + MatrixIndexT nr = this->num_rows_; + for (MatrixIndexT i = 0; i < nr; i++) + for (MatrixIndexT j = 0; j <= i; j++, data++) + *data += alpha * v_data[i] * v_data[j]; +} + +// instantiate the template above. 
+template +void SpMatrix::AddVec2(const float alpha, const VectorBase &v); +template +void SpMatrix::AddVec2(const double alpha, const VectorBase &v); + + +template +Real VecSpVec(const VectorBase &v1, const SpMatrix &M, + const VectorBase &v2) { + MatrixIndexT D = M.NumRows(); + KALDI_ASSERT(v1.Dim() == D && v1.Dim() == v2.Dim()); + Vector tmp_vec(D); + cblas_Xspmv(D, 1.0, M.Data(), v1.Data(), 1, 0.0, tmp_vec.Data(), 1); + return VecVec(tmp_vec, v2); +} + +template +float VecSpVec(const VectorBase &v1, const SpMatrix &M, + const VectorBase &v2); +template +double VecSpVec(const VectorBase &v1, const SpMatrix &M, + const VectorBase &v2); + + +template +void SpMatrix::AddMat2Sp( + const Real alpha, const MatrixBase &M, + MatrixTransposeType transM, const SpMatrix &A, const Real beta) { + if (transM == kNoTrans) { + KALDI_ASSERT(M.NumCols() == A.NumRows() && M.NumRows() == this->num_rows_); + } else { + KALDI_ASSERT(M.NumRows() == A.NumRows() && M.NumCols() == this->num_rows_); + } + Vector tmp_vec(A.NumRows()); + Real *tmp_vec_data = tmp_vec.Data(); + SpMatrix tmp_A; + const Real *p_A_data = A.Data(); + Real *p_row_data = this->Data(); + MatrixIndexT M_other_dim = (transM == kNoTrans ? M.NumCols() : M.NumRows()), + M_same_dim = (transM == kNoTrans ? M.NumRows() : M.NumCols()), + M_stride = M.Stride(), dim = this->NumRows(); + KALDI_ASSERT(M_same_dim == dim); + + const Real *M_data = M.Data(); + + if (this->Data() <= A.Data() + A.SizeInBytes() && + this->Data() + this->SizeInBytes() >= A.Data()) { + // Matrices A and *this overlap. 
Make copy of A + tmp_A.Resize(A.NumRows()); + tmp_A.CopyFromSp(A); + p_A_data = tmp_A.Data(); + } + + if (transM == kNoTrans) { + for (MatrixIndexT r = 0; r < dim; r++, p_row_data += r) { + cblas_Xspmv(A.NumRows(), 1.0, p_A_data, M.RowData(r), 1, 0.0, tmp_vec_data, 1); + cblas_Xgemv(transM, r+1, M_other_dim, alpha, M_data, M_stride, + tmp_vec_data, 1, beta, p_row_data, 1); + } + } else { + for (MatrixIndexT r = 0; r < dim; r++, p_row_data += r) { + cblas_Xspmv(A.NumRows(), 1.0, p_A_data, M.Data() + r, M.Stride(), 0.0, tmp_vec_data, 1); + cblas_Xgemv(transM, M_other_dim, r+1, alpha, M_data, M_stride, + tmp_vec_data, 1, beta, p_row_data, 1); + } + } +} + +template +void SpMatrix::AddSmat2Sp( + const Real alpha, const MatrixBase &M, + MatrixTransposeType transM, const SpMatrix &A, + const Real beta) { + KALDI_ASSERT((transM == kNoTrans && M.NumCols() == A.NumRows()) || + (transM == kTrans && M.NumRows() == A.NumRows())); + if (transM == kNoTrans) { + KALDI_ASSERT(M.NumCols() == A.NumRows() && M.NumRows() == this->num_rows_); + } else { + KALDI_ASSERT(M.NumRows() == A.NumRows() && M.NumCols() == this->num_rows_); + } + MatrixIndexT Adim = A.NumRows(), dim = this->num_rows_; + + Matrix temp_A(A); // represent A as full matrix. + Matrix temp_MA(dim, Adim); + temp_MA.AddSmatMat(1.0, M, transM, temp_A, kNoTrans, 0.0); + + // Next-- we want to do *this = alpha * temp_MA * M^T + beta * *this. + // To make it sparse vector multiplies, since M is sparse, we'd like + // to do: for each column c, (*this column c) += temp_MA * (M^T's column c.) + // [ignoring the alpha and beta here.] + // It's not convenient to process columns in the symmetric + // packed format because they don't have a constant stride. However, + // we can use the fact that temp_MA * M is symmetric, to just assign + // each row of *this instead of each column. + // So the final iteration is: + // for i = 0... 
dim-1, + // [the i'th row of *this] = beta * [the i'th row of *this] + alpha * + // temp_MA * [the i'th column of M]. + // Of course, we only process the first 0 ... i elements of this row, + // as that's all that are kept in the symmetric packed format. + + Matrix temp_this(*this); + Real *data = this->data_; + const Real *Mdata = M.Data(), *MAdata = temp_MA.Data(); + MatrixIndexT temp_MA_stride = temp_MA.Stride(), Mstride = M.Stride(); + + if (transM == kNoTrans) { + // The column of M^T corresponds to the rows of the supplied matrix. + for (MatrixIndexT i = 0; i < dim; i++, data += i) { + MatrixIndexT num_rows = i + 1, num_cols = Adim; + Xgemv_sparsevec(kNoTrans, num_rows, num_cols, alpha, MAdata, + temp_MA_stride, Mdata + (i * Mstride), 1, beta, data, 1); + } + } else { + // The column of M^T corresponds to the columns of the supplied matrix. + for (MatrixIndexT i = 0; i < dim; i++, data += i) { + MatrixIndexT num_rows = i + 1, num_cols = Adim; + Xgemv_sparsevec(kNoTrans, num_rows, num_cols, alpha, MAdata, + temp_MA_stride, Mdata + i, Mstride, beta, data, 1); + } + } +} + +template +void SpMatrix::AddMat2Vec(const Real alpha, + const MatrixBase &M, + MatrixTransposeType transM, + const VectorBase &v, + const Real beta) { + this->Scale(beta); + KALDI_ASSERT((transM == kNoTrans && this->NumRows() == M.NumRows() && + M.NumCols() == v.Dim()) || + (transM == kTrans && this->NumRows() == M.NumCols() && + M.NumRows() == v.Dim())); + + if (transM == kNoTrans) { + const Real *Mdata = M.Data(), *vdata = v.Data(); + Real *data = this->data_; + MatrixIndexT dim = this->NumRows(), mcols = M.NumCols(), + mstride = M.Stride(); + for (MatrixIndexT col = 0; col < mcols; col++, vdata++, Mdata += 1) + cblas_Xspr(dim, *vdata*alpha, Mdata, mstride, data); + } else { + const Real *Mdata = M.Data(), *vdata = v.Data(); + Real *data = this->data_; + MatrixIndexT dim = this->NumRows(), mrows = M.NumRows(), + mstride = M.Stride(); + for (MatrixIndexT row = 0; row < mrows; row++, vdata++, 
Mdata += mstride) + cblas_Xspr(dim, *vdata*alpha, Mdata, 1, data); + } +} + +template +void SpMatrix::AddMat2(const Real alpha, const MatrixBase &M, + MatrixTransposeType transM, const Real beta) { + KALDI_ASSERT((transM == kNoTrans && this->NumRows() == M.NumRows()) + || (transM == kTrans && this->NumRows() == M.NumCols())); + + // Cblas has no function *sprk (i.e. symmetric packed rank-k update), so we + // use as temporary storage a regular matrix of which we only access its lower + // triangle + + MatrixIndexT this_dim = this->NumRows(), + m_other_dim = (transM == kNoTrans ? M.NumCols() : M.NumRows()); + + if (this_dim == 0) return; + if (alpha == 0.0) { + if (beta != 1.0) this->Scale(beta); + return; + } + + Matrix temp_mat(*this); // wastefully copies upper triangle too, but this + // doesn't dominate O(N) time. + + // This function call is hard-coded to update the lower triangle. + cblas_Xsyrk(transM, this_dim, m_other_dim, alpha, M.Data(), + M.Stride(), beta, temp_mat.Data(), temp_mat.Stride()); + + this->CopyFromMat(temp_mat, kTakeLower); +} + +template +void SpMatrix::AddTp2Sp(const Real alpha, const TpMatrix &T, + MatrixTransposeType transM, const SpMatrix &A, + const Real beta) { + Matrix Tmat(T); + AddMat2Sp(alpha, Tmat, transM, A, beta); +} + +template +void SpMatrix::AddVecVec(const Real alpha, const VectorBase &v, + const VectorBase &w) { + int32 dim = this->NumRows(); + KALDI_ASSERT(dim == v.Dim() && dim == w.Dim() && dim > 0); + cblas_Xspr2(dim, alpha, v.Data(), 1, w.Data(), 1, this->data_); +} + + +template +void SpMatrix::AddTp2(const Real alpha, const TpMatrix &T, + MatrixTransposeType transM, const Real beta) { + Matrix Tmat(T); + AddMat2(alpha, Tmat, transM, beta); +} + + +// Explicit instantiation of the class. +// This needs to be after the definition of all the class member functions. 
+ +template class SpMatrix; +template class SpMatrix; + + +template +Real TraceSpSpLower(const SpMatrix &A, const SpMatrix &B) { + MatrixIndexT adim = A.NumRows(); + KALDI_ASSERT(adim == B.NumRows()); + MatrixIndexT dim = (adim*(adim+1))/2; + return cblas_Xdot(dim, A.Data(), 1, B.Data(), 1); +} +// Instantiate the template above. +template +double TraceSpSpLower(const SpMatrix &A, const SpMatrix &B); +template +float TraceSpSpLower(const SpMatrix &A, const SpMatrix &B); + +// Instantiate the template above. +template float SolveQuadraticMatrixProblem(const SpMatrix &Q, + const MatrixBase &Y, + const SpMatrix &SigmaInv, + const SolverOptions &opts, + MatrixBase *M); +template double SolveQuadraticMatrixProblem(const SpMatrix &Q, + const MatrixBase &Y, + const SpMatrix &SigmaInv, + const SolverOptions &opts, + MatrixBase *M); + +// Instantiate the template above. +template float SolveDoubleQuadraticMatrixProblem( + const MatrixBase &G, + const SpMatrix &P1, + const SpMatrix &P2, + const SpMatrix &Q1, + const SpMatrix &Q2, + const SolverOptions &opts, + MatrixBase *M); + +template double SolveDoubleQuadraticMatrixProblem( + const MatrixBase &G, + const SpMatrix &P1, + const SpMatrix &P2, + const SpMatrix &Q1, + const SpMatrix &Q2, + const SolverOptions &opts, + MatrixBase *M); + + + +} // namespace kaldi diff --git a/speechx/speechx/kaldi/matrix/sp-matrix.h b/speechx/speechx/kaldi/matrix/sp-matrix.h new file mode 100644 index 00000000..26d9ad6f --- /dev/null +++ b/speechx/speechx/kaldi/matrix/sp-matrix.h @@ -0,0 +1,517 @@ +// matrix/sp-matrix.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Lukas Burget; +// Saarland University; Ariya Rastrow; Yanmin Qian; +// Jan Silovsky + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_MATRIX_SP_MATRIX_H_ +#define KALDI_MATRIX_SP_MATRIX_H_ + +#include +#include + +#include "matrix/packed-matrix.h" + +namespace kaldi { + + +/// \addtogroup matrix_group +/// @{ +template class SpMatrix; + + +/** + * @brief Packed symetric matrix class +*/ +template +class SpMatrix : public PackedMatrix { + friend class CuSpMatrix; + public: + // so it can use our assignment operator. + friend class std::vector >; + + SpMatrix(): PackedMatrix() {} + + /// Copy constructor from CUDA version of SpMatrix + /// This is defined in ../cudamatrix/cu-sp-matrix.h + + explicit SpMatrix(const CuSpMatrix &cu); + + explicit SpMatrix(MatrixIndexT r, MatrixResizeType resize_type = kSetZero) + : PackedMatrix(r, resize_type) {} + + SpMatrix(const SpMatrix &orig) + : PackedMatrix(orig) {} + + template + explicit SpMatrix(const SpMatrix &orig) + : PackedMatrix(orig) {} + +#ifdef KALDI_PARANOID + explicit SpMatrix(const MatrixBase & orig, + SpCopyType copy_type = kTakeMeanAndCheck) + : PackedMatrix(orig.NumRows(), kUndefined) { + CopyFromMat(orig, copy_type); + } +#else + explicit SpMatrix(const MatrixBase & orig, + SpCopyType copy_type = kTakeMean) + : PackedMatrix(orig.NumRows(), kUndefined) { + CopyFromMat(orig, copy_type); + } +#endif + + /// Shallow swap. 
+ void Swap(SpMatrix *other); + + inline void Resize(MatrixIndexT nRows, MatrixResizeType resize_type = kSetZero) { + PackedMatrix::Resize(nRows, resize_type); + } + + void CopyFromSp(const SpMatrix &other) { + PackedMatrix::CopyFromPacked(other); + } + + template + void CopyFromSp(const SpMatrix &other) { + PackedMatrix::CopyFromPacked(other); + } + +#ifdef KALDI_PARANOID + void CopyFromMat(const MatrixBase &orig, + SpCopyType copy_type = kTakeMeanAndCheck); +#else // different default arg if non-paranoid mode. + void CopyFromMat(const MatrixBase &orig, + SpCopyType copy_type = kTakeMean); +#endif + + inline Real operator() (MatrixIndexT r, MatrixIndexT c) const { + // if column is less than row, then swap these as matrix is stored + // as upper-triangular... only allowed for const matrix object. + if (static_cast(c) > + static_cast(r)) + std::swap(c, r); + // c<=r now so don't have to check c. + KALDI_ASSERT(static_cast(r) < + static_cast(this->num_rows_)); + return *(this->data_ + (r*(r+1)) / 2 + c); + // Duplicating code from PackedMatrix.h + } + + inline Real &operator() (MatrixIndexT r, MatrixIndexT c) { + if (static_cast(c) > + static_cast(r)) + std::swap(c, r); + // c<=r now so don't have to check c. + KALDI_ASSERT(static_cast(r) < + static_cast(this->num_rows_)); + return *(this->data_ + (r * (r + 1)) / 2 + c); + // Duplicating code from PackedMatrix.h + } + + SpMatrix& operator=(const SpMatrix &other) { + PackedMatrix::operator=(other); + return *this; + } + + using PackedMatrix::Scale; + + /// matrix inverse. + /// if inverse_needed = false, will fill matrix with garbage. + /// (only useful if logdet wanted). + void Invert(Real *logdet = NULL, Real *det_sign= NULL, + bool inverse_needed = true); + + // Below routine does inversion in double precision, + // even for single-precision object. + void InvertDouble(Real *logdet = NULL, Real *det_sign = NULL, + bool inverse_needed = true); + + /// Returns maximum ratio of singular values. 
+ inline Real Cond() const { + Matrix tmp(*this); + return tmp.Cond(); + } + + /// Takes matrix to a fraction power via Svd. + /// Will throw exception if matrix is not positive semidefinite + /// (to within a tolerance) + void ApplyPow(Real exponent); + + /// This is the version of SVD that we implement for symmetric positive + /// definite matrices. This exists for historical reasons; right now its + /// internal implementation is the same as Eig(). It computes the eigenvalue + /// decomposition (*this) = P * diag(s) * P^T with P orthogonal. Will throw + /// exception if input is not positive semidefinite to within a tolerance. + void SymPosSemiDefEig(VectorBase *s, MatrixBase *P, + Real tolerance = 0.001) const; + + /// Solves the symmetric eigenvalue problem: at end we should have (*this) = P + /// * diag(s) * P^T. We solve the problem using the symmetric QR method. + /// P may be NULL. + /// Implemented in qr.cc. + /// If you need the eigenvalues sorted, the function SortSvd declared in + /// kaldi-matrix is suitable. + void Eig(VectorBase *s, MatrixBase *P = NULL) const; + + /// This function gives you, approximately, the largest eigenvalues of the + /// symmetric matrix and the corresponding eigenvectors. (largest meaning, + /// further from zero). It does this by doing a SVD within the Krylov + /// subspace generated by this matrix and a random vector. This is + /// a form of the Lanczos method with complete reorthogonalization, followed + /// by SVD within a smaller dimension ("lanczos_dim"). + /// + /// If *this is m by m, s should be of dimension n and P should be of + /// dimension m by n, with n <= m. The *columns* of P are the approximate + /// eigenvectors; P * diag(s) * P^T would be a low-rank reconstruction of + /// *this. The columns of P will be orthogonal, and the elements of s will be + /// the eigenvalues of *this projected into that subspace, but beyond that + /// there are no exact guarantees. 
(This is because the convergence of this + /// method is statistical). Note: it only makes sense to use this + /// method if you are in very high dimension and n is substantially smaller + /// than m: for example, if you want the 100 top eigenvalues of a 10k by 10k + /// matrix. This function calls Rand() to initialize the lanczos + /// iterations and also for restarting. + /// If lanczos_dim is zero, it will default to the greater of: + /// s->Dim() + 50 or s->Dim() + s->Dim()/2, but not more than this->Dim(). + /// If lanczos_dim == this->Dim(), you might as well just call the function + /// Eig() since the result will be the same, and Eig() would be faster; the + /// whole point of this function is to reduce the dimension of the SVD + /// computation. + void TopEigs(VectorBase *s, MatrixBase *P, + MatrixIndexT lanczos_dim = 0) const; + + + /// Returns the maximum of the absolute values of any of the + /// eigenvalues. + Real MaxAbsEig() const; + + void PrintEigs(const char *name) { + Vector s((*this).NumRows()); + Matrix P((*this).NumRows(), (*this).NumCols()); + SymPosSemiDefEig(&s, &P); + KALDI_LOG << "PrintEigs: " << name << ": " << s; + } + + bool IsPosDef() const; // returns true if Cholesky succeeds. + void AddSp(const Real alpha, const SpMatrix &Ma) { + this->AddPacked(alpha, Ma); + } + + /// Computes log determinant but only for +ve-def matrices + /// (it uses Cholesky). + /// If matrix is not +ve-def, it will throw an exception + /// was LogPDDeterminant() + Real LogPosDefDet() const; + + Real LogDet(Real *det_sign = NULL) const; + + /// rank-one update, this <-- this + alpha v v' + template + void AddVec2(const Real alpha, const VectorBase &v); + + /// rank-two update, this <-- this + alpha (v w' + w v'). 
+ void AddVecVec(const Real alpha, const VectorBase &v, + const VectorBase &w); + + /// Does *this = beta * *thi + alpha * diag(v) * S * diag(v) + void AddVec2Sp(const Real alpha, const VectorBase &v, + const SpMatrix &S, const Real beta); + + /// diagonal update, this <-- this + diag(v) + template + void AddDiagVec(const Real alpha, const VectorBase &v); + + /// rank-N update: + /// if (transM == kNoTrans) + /// (*this) = beta*(*this) + alpha * M * M^T, + /// or (if transM == kTrans) + /// (*this) = beta*(*this) + alpha * M^T * M + /// Note: beta used to default to 0.0. + void AddMat2(const Real alpha, const MatrixBase &M, + MatrixTransposeType transM, const Real beta); + + /// Extension of rank-N update: + /// this <-- beta*this + alpha * M * A * M^T. + /// (*this) and A are allowed to be the same. + /// If transM == kTrans, then we do it as M^T * A * M. + void AddMat2Sp(const Real alpha, const MatrixBase &M, + MatrixTransposeType transM, const SpMatrix &A, + const Real beta = 0.0); + + /// This is a version of AddMat2Sp specialized for when M is fairly sparse. + /// This was required for making the raw-fMLLR code efficient. + void AddSmat2Sp(const Real alpha, const MatrixBase &M, + MatrixTransposeType transM, const SpMatrix &A, + const Real beta = 0.0); + + /// The following function does: + /// this <-- beta*this + alpha * T * A * T^T. + /// (*this) and A are allowed to be the same. + /// If transM == kTrans, then we do it as alpha * T^T * A * T. + /// Currently it just calls AddMat2Sp, but if needed we + /// can implement it more efficiently. + void AddTp2Sp(const Real alpha, const TpMatrix &T, + MatrixTransposeType transM, const SpMatrix &A, + const Real beta = 0.0); + + /// The following function does: + /// this <-- beta*this + alpha * T * T^T. + /// (*this) and A are allowed to be the same. + /// If transM == kTrans, then we do it as alpha * T^T * T + /// Currently it just calls AddMat2, but if needed we + /// can implement it more efficiently. 
+ void AddTp2(const Real alpha, const TpMatrix &T, + MatrixTransposeType transM, const Real beta = 0.0); + + /// Extension of rank-N update: + /// this <-- beta*this + alpha * M * diag(v) * M^T. + /// if transM == kTrans, then + /// this <-- beta*this + alpha * M^T * diag(v) * M. + void AddMat2Vec(const Real alpha, const MatrixBase &M, + MatrixTransposeType transM, const VectorBase &v, + const Real beta = 0.0); + + + /// Floors this symmetric matrix to the matrix + /// alpha * Floor, where the matrix Floor is positive + /// definite. + /// It is floored in the sense that after flooring, + /// x^T (*this) x >= x^T (alpha*Floor) x. + /// This is accomplished using an Svd. It will crash + /// if Floor is not positive definite. Returns the number of + /// elements that were floored. + int ApplyFloor(const SpMatrix &Floor, Real alpha = 1.0, + bool verbose = false); + + /// Floor: Given a positive semidefinite matrix, floors the eigenvalues + /// to the specified quantity. A previous version of this function had + /// a tolerance which is now no longer needed since we have code to + /// do the symmetric eigenvalue decomposition and no longer use the SVD + /// code for that purose. + int ApplyFloor(Real floor); + + bool IsDiagonal(Real cutoff = 1.0e-05) const; + bool IsUnit(Real cutoff = 1.0e-05) const; + bool IsZero(Real cutoff = 1.0e-05) const; + bool IsTridiagonal(Real cutoff = 1.0e-05) const; + + /// sqrt of sum of square elements. + Real FrobeniusNorm() const; + + /// Returns true if ((*this)-other).FrobeniusNorm() <= + /// tol*(*this).FrobeniusNorma() + bool ApproxEqual(const SpMatrix &other, float tol = 0.01) const; + + // LimitCond: + // Limits the condition of symmetric positive semidefinite matrix to + // a specified value + // by flooring all eigenvalues to a positive number which is some multiple + // of the largest one (or zero if there are no positive eigenvalues). 
+ // Takes the condition number we are willing to accept, and floors + // eigenvalues to the largest eigenvalue divided by this. + // Returns #eigs floored or already equal to the floor. + // Throws exception if input is not positive definite. + // returns #floored. + MatrixIndexT LimitCond(Real maxCond = 1.0e+5, bool invert = false); + + // as LimitCond but all done in double precision. // returns #floored. + MatrixIndexT LimitCondDouble(Real maxCond = 1.0e+5, bool invert = false) { + SpMatrix dmat(*this); + MatrixIndexT ans = dmat.LimitCond(maxCond, invert); + (*this).CopyFromSp(dmat); + return ans; + } + Real Trace() const; + + /// Tridiagonalize the matrix with an orthogonal transformation. If + /// *this starts as S, produce T (and Q, if non-NULL) such that + /// T = Q A Q^T, i.e. S = Q^T T Q. Caution: this is the other way + /// round from most authors (it's more efficient in row-major indexing). + void Tridiagonalize(MatrixBase *Q); + + /// The symmetric QR algorithm. This will mostly be useful in internal code. + /// Typically, you will call this after Tridiagonalize(), on the same object. + /// When called, *this (call it A at this point) must be tridiagonal; at exit, + /// *this will be a diagonal matrix D that is similar to A via orthogonal + /// transformations. This algorithm right-multiplies Q by orthogonal + /// transformations. It turns *this from a tridiagonal into a diagonal matrix + /// while maintaining that (Q *this Q^T) has the same value at entry and exit. + /// At entry Q should probably be either NULL or orthogonal, but we don't check + /// this. + void Qr(MatrixBase *Q); + + private: + void EigInternal(VectorBase *s, MatrixBase *P, + Real tolerance, int recurse) const; +}; + +/// @} end of "addtogroup matrix_group" + +/// \addtogroup matrix_funcs_scalar +/// @{ + + +/// Returns tr(A B). 
+float TraceSpSp(const SpMatrix &A, const SpMatrix &B); +double TraceSpSp(const SpMatrix &A, const SpMatrix &B); + + +template +inline bool ApproxEqual(const SpMatrix &A, + const SpMatrix &B, Real tol = 0.01) { + return A.ApproxEqual(B, tol); +} + +template +inline void AssertEqual(const SpMatrix &A, + const SpMatrix &B, Real tol = 0.01) { + KALDI_ASSERT(ApproxEqual(A, B, tol)); +} + + + +/// Returns tr(A B). +template +Real TraceSpSp(const SpMatrix &A, const SpMatrix &B); + + + +// TraceSpSpLower is the same as Trace(A B) except the lower-diagonal elements +// are counted only once not twice as they should be. It is useful in certain +// optimizations. +template +Real TraceSpSpLower(const SpMatrix &A, const SpMatrix &B); + + +/// Returns tr(A B). +/// No option to transpose B because would make no difference. +template +Real TraceSpMat(const SpMatrix &A, const MatrixBase &B); + +/// Returns tr(A B C) +/// (A and C may be transposed as specified by transA and transC). +template +Real TraceMatSpMat(const MatrixBase &A, MatrixTransposeType transA, + const SpMatrix &B, const MatrixBase &C, + MatrixTransposeType transC); + +/// Returns tr (A B C D) +/// (A and C may be transposed as specified by transA and transB). +template +Real TraceMatSpMatSp(const MatrixBase &A, MatrixTransposeType transA, + const SpMatrix &B, const MatrixBase &C, + MatrixTransposeType transC, const SpMatrix &D); + +/** Computes v1^T * M * v2. Not as efficient as it could be where v1 == v2 + * (but no suitable blas routines available). + */ + +/// Returns \f$ v_1^T M v_2 \f$ +/// Not as efficient as it could be where v1 == v2. +template +Real VecSpVec(const VectorBase &v1, const SpMatrix &M, + const VectorBase &v2); + + +/// @} \addtogroup matrix_funcs_scalar + +/// \addtogroup matrix_funcs_misc +/// @{ + + +/// This class describes the options for maximizing various quadratic objective +/// functions. 
It's mostly as described in the SGMM paper "the subspace +/// Gaussian mixture model -- a structured model for speech recognition", but +/// the diagonal_precondition option is newly added, to handle problems where +/// different dimensions have very different scaling (we recommend to use the +/// option but it's set false for back compatibility). +struct SolverOptions { + BaseFloat K; // maximum condition number + BaseFloat eps; + std::string name; + bool optimize_delta; + bool diagonal_precondition; + bool print_debug_output; + explicit SolverOptions(const std::string &name): + K(1.0e+4), eps(1.0e-40), name(name), + optimize_delta(true), diagonal_precondition(false), + print_debug_output(true) { } + SolverOptions(): K(1.0e+4), eps(1.0e-40), name("[unknown]"), + optimize_delta(true), diagonal_precondition(false), + print_debug_output(true) { } + void Check() const; +}; + + +/// Maximizes the auxiliary function +/// \f[ Q(x) = x.g - 0.5 x^T H x \f] +/// using a numerically stable method. Like a numerically stable version of +/// \f$ x := Q^{-1} g. \f$ +/// Assumes H positive semidefinite. +/// Returns the objective-function change. + +template +Real SolveQuadraticProblem(const SpMatrix &H, + const VectorBase &g, + const SolverOptions &opts, + VectorBase *x); + + + +/// Maximizes the auxiliary function : +/// \f[ Q(x) = tr(M^T P Y) - 0.5 tr(P M Q M^T) \f] +/// Like a numerically stable version of \f$ M := Y Q^{-1} \f$. +/// Assumes Q and P positive semidefinite, and matrix dimensions match +/// enough to make expressions meaningful. +/// This is mostly as described in the SGMM paper "the subspace Gaussian mixture +/// model -- a structured model for speech recognition", but the +/// diagonal_precondition option is newly added, to handle problems +/// where different dimensions have very different scaling (we recommend to use +/// the option but it's set false for back compatibility). 
+template +Real SolveQuadraticMatrixProblem(const SpMatrix &Q, + const MatrixBase &Y, + const SpMatrix &P, + const SolverOptions &opts, + MatrixBase *M); + +/// Maximizes the auxiliary function : +/// \f[ Q(M) = tr(M^T G) -0.5 tr(P_1 M Q_1 M^T) -0.5 tr(P_2 M Q_2 M^T). \f] +/// Encountered in matrix update with a prior. We also apply a limit on the +/// condition but it should be less frequently necessary, and can be set larger. +template +Real SolveDoubleQuadraticMatrixProblem(const MatrixBase &G, + const SpMatrix &P1, + const SpMatrix &P2, + const SpMatrix &Q1, + const SpMatrix &Q2, + const SolverOptions &opts, + MatrixBase *M); + + +/// @} End of "addtogroup matrix_funcs_misc" + +} // namespace kaldi + + +// Including the implementation (now actually just includes some +// template specializations). +#include "matrix/sp-matrix-inl.h" + + +#endif // KALDI_MATRIX_SP_MATRIX_H_ diff --git a/speechx/speechx/kaldi/matrix/sparse-matrix.cc b/speechx/speechx/kaldi/matrix/sparse-matrix.cc new file mode 100644 index 00000000..68a61e17 --- /dev/null +++ b/speechx/speechx/kaldi/matrix/sparse-matrix.cc @@ -0,0 +1,1296 @@ +// matrix/sparse-matrix.cc + +// Copyright 2015 Johns Hopkins University (author: Daniel Povey) +// 2015 Guoguo Chen +// 2017 Shiyin Kang + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "matrix/sparse-matrix.h" +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + +template +std::pair* SparseVector::Data() { + if (pairs_.empty()) + return NULL; + else + return &(pairs_[0]); +} + +template +const std::pair* SparseVector::Data() const { + if (pairs_.empty()) + return NULL; + else + return &(pairs_[0]); +} + +template +Real SparseVector::Sum() const { + Real sum = 0; + for (int32 i = 0; i < pairs_.size(); ++i) { + sum += pairs_[i].second; + } + return sum; +} + +template +void SparseVector::Scale(Real alpha) { + for (int32 i = 0; i < pairs_.size(); ++i) + pairs_[i].second *= alpha; +} + +template +template +void SparseVector::CopyElementsToVec(VectorBase *vec) const { + KALDI_ASSERT(vec->Dim() == this->dim_); + vec->SetZero(); + OtherReal *other_data = vec->Data(); + typename std::vector >::const_iterator + iter = pairs_.begin(), end = pairs_.end(); + for (; iter != end; ++iter) + other_data[iter->first] = iter->second; +} +template +void SparseVector::CopyElementsToVec(VectorBase *vec) const; +template +void SparseVector::CopyElementsToVec(VectorBase *vec) const; +template +void SparseVector::CopyElementsToVec(VectorBase *vec) const; +template +void SparseVector::CopyElementsToVec(VectorBase *vec) const; + +template +template +void SparseVector::AddToVec(Real alpha, + VectorBase *vec) const { + KALDI_ASSERT(vec->Dim() == dim_); + OtherReal *other_data = vec->Data(); + typename std::vector >::const_iterator + iter = pairs_.begin(), end = pairs_.end(); + if (alpha == 1.0) { // treat alpha==1.0 case specially. 
+ for (; iter != end; ++iter) + other_data[iter->first] += iter->second; + } else { + for (; iter != end; ++iter) + other_data[iter->first] += alpha * iter->second; + } +} + +template +void SparseVector::AddToVec(float alpha, VectorBase *vec) const; +template +void SparseVector::AddToVec(float alpha, VectorBase *vec) const; +template +void SparseVector::AddToVec(double alpha, VectorBase *vec) const; +template +void SparseVector::AddToVec(double alpha, + VectorBase *vec) const; + +template +template +void SparseVector::CopyFromSvec(const SparseVector &other) { + dim_ = other.Dim(); + pairs_.clear(); + if (dim_ == 0) return; + for (int32 i = 0; i < other.NumElements(); ++i) { + pairs_.push_back(std::make_pair( + other.GetElement(i).first, + static_cast(other.GetElement(i).second))); + } +} +template +void SparseVector::CopyFromSvec(const SparseVector &svec); +template +void SparseVector::CopyFromSvec(const SparseVector &svec); +template +void SparseVector::CopyFromSvec(const SparseVector &svec); +template +void SparseVector::CopyFromSvec(const SparseVector &svec); + + +template +SparseVector& SparseVector::operator = ( + const SparseVector &other) { + this->CopyFromSvec(other); + dim_ = other.dim_; + pairs_ = other.pairs_; + return *this; +} + +template +void SparseVector::Swap(SparseVector *other) { + pairs_.swap(other->pairs_); + std::swap(dim_, other->dim_); +} + +template +void SparseVector::Write(std::ostream &os, bool binary) const { + if (binary) { + WriteToken(os, binary, "SV"); + WriteBasicType(os, binary, dim_); + MatrixIndexT num_elems = pairs_.size(); + WriteBasicType(os, binary, num_elems); + typename std::vector >::const_iterator + iter = pairs_.begin(), end = pairs_.end(); + for (; iter != end; ++iter) { + WriteBasicType(os, binary, iter->first); + WriteBasicType(os, binary, iter->second); + } + } else { + // In text-mode, use a human-friendly, script-friendly format; + // format is "dim=5 [ 0 0.2 3 0.9 ] " + os << "dim=" << dim_ << " [ "; + typename 
std::vector >::const_iterator + iter = pairs_.begin(), end = pairs_.end(); + for (; iter != end; ++iter) + os << iter->first << ' ' << iter->second << ' '; + os << "] "; + } +} + + +template +void SparseVector::Read(std::istream &is, bool binary) { + if (binary) { + ExpectToken(is, binary, "SV"); + ReadBasicType(is, binary, &dim_); + KALDI_ASSERT(dim_ >= 0); + int32 num_elems; + ReadBasicType(is, binary, &num_elems); + KALDI_ASSERT(num_elems >= 0 && num_elems <= dim_); + pairs_.resize(num_elems); + typename std::vector >::iterator + iter = pairs_.begin(), end = pairs_.end(); + for (; iter != end; ++iter) { + ReadBasicType(is, binary, &(iter->first)); + ReadBasicType(is, binary, &(iter->second)); + } + } else { + // In text-mode, format is "dim=5 [ 0 0.2 3 0.9 ] + std::string str; + is >> str; + if (str.substr(0, 4) != "dim=") + KALDI_ERR << "Reading sparse vector, expected 'dim=xxx', got " << str; + std::string dim_str = str.substr(4, std::string::npos); + std::istringstream dim_istr(dim_str); + int32 dim = -1; + dim_istr >> dim; + if (dim < 0 || dim_istr.fail()) { + KALDI_ERR << "Reading sparse vector, expected 'dim=[int]', got " << str; + } + dim_ = dim; + is >> std::ws; + is >> str; + if (str != "[") + KALDI_ERR << "Reading sparse vector, expected '[', got " << str; + pairs_.clear(); + while (1) { + is >> std::ws; + if (is.peek() == ']') { + is.get(); + break; + } + MatrixIndexT i; + BaseFloat p; + is >> i >> p; + if (is.fail()) + KALDI_ERR << "Error reading sparse vector, expecting numbers."; + KALDI_ASSERT(i >= 0 && i < dim + && (pairs_.empty() || i > pairs_.back().first)); + pairs_.push_back(std::pair(i, p)); + } + } +} + + +namespace sparse_vector_utils { +template +struct CompareFirst { + inline bool operator() (const std::pair &p1, + const std::pair &p2) const { + return p1.first < p2.first; + } +}; +} + +template +SparseVector::SparseVector( + MatrixIndexT dim, const std::vector > &pairs): + dim_(dim), + pairs_(pairs) { + std::sort(pairs_.begin(), 
pairs_.end(), + sparse_vector_utils::CompareFirst()); + typename std::vector >::iterator + out = pairs_.begin(), in = out, end = pairs_.end(); + // special case: while there is nothing to be changed, skip over + // initial input (avoids unnecessary copying). + while (in + 1 < end && in[0].first != in[1].first && in[0].second != 0.0) { + in++; + out++; + } + while (in < end) { + // We reach this point only at the first element of + // each stretch of identical .first elements. + *out = *in; + ++in; + while (in < end && in->first == out->first) { + out->second += in->second; // this is the merge operation. + ++in; + } + if (out->second != Real(0.0)) // Don't keep zero elements. + out++; + } + pairs_.erase(out, end); + if (!pairs_.empty()) { + // range check. + KALDI_ASSERT(pairs_.front().first >= 0 && pairs_.back().first < dim_); + } +} + +template +void SparseVector::SetRandn(BaseFloat zero_prob) { + pairs_.clear(); + KALDI_ASSERT(zero_prob >= 0 && zero_prob <= 1.0); + for (MatrixIndexT i = 0; i < dim_; i++) + if (WithProb(1.0 - zero_prob)) + pairs_.push_back(std::pair(i, RandGauss())); +} + +template +void SparseVector::Resize(MatrixIndexT dim, + MatrixResizeType resize_type) { + if (resize_type != kCopyData || dim == 0) + pairs_.clear(); + KALDI_ASSERT(dim >= 0); + if (dim < dim_ && resize_type == kCopyData) + while (!pairs_.empty() && pairs_.back().first >= dim) + pairs_.pop_back(); + dim_ = dim; +} + +template +MatrixIndexT SparseMatrix::NumRows() const { + return rows_.size(); +} + +template +MatrixIndexT SparseMatrix::NumCols() const { + if (rows_.empty()) + return 0.0; + else + return rows_[0].Dim(); +} + +template +MatrixIndexT SparseMatrix::NumElements() const { + int32 num_elements = 0; + for (int32 i = 0; i < rows_.size(); ++i) { + num_elements += rows_[i].NumElements(); + } + return num_elements; +} + +template +SparseVector* SparseMatrix::Data() { + if (rows_.empty()) + return NULL; + else + return rows_.data(); +} + +template +const SparseVector* 
SparseMatrix::Data() const { + if (rows_.empty()) + return NULL; + else + return rows_.data(); +} + +template +Real SparseMatrix::Sum() const { + Real sum = 0; + for (int32 i = 0; i < rows_.size(); ++i) { + sum += rows_[i].Sum(); + } + return sum; +} + +template +Real SparseMatrix::FrobeniusNorm() const { + Real squared_sum = 0; + for (int32 i = 0; i < rows_.size(); ++i) { + const std::pair *row_data = rows_[i].Data(); + for (int32 j = 0; j < rows_[i].NumElements(); ++j) { + squared_sum += row_data[j].second * row_data[j].second; + } + } + return std::sqrt(squared_sum); +} + +template +template +void SparseMatrix::CopyToMat(MatrixBase *other, + MatrixTransposeType trans) const { + if (trans == kNoTrans) { + MatrixIndexT num_rows = rows_.size(); + KALDI_ASSERT(other->NumRows() == num_rows); + for (MatrixIndexT i = 0; i < num_rows; i++) { + SubVector vec(*other, i); + rows_[i].CopyElementsToVec(&vec); + } + } else { + OtherReal *other_col_data = other->Data(); + MatrixIndexT other_stride = other->Stride(), + num_rows = NumRows(), num_cols = NumCols(); + KALDI_ASSERT(num_rows == other->NumCols() && num_cols == other->NumRows()); + other->SetZero(); + for (MatrixIndexT row = 0; row < num_rows; row++, other_col_data++) { + const SparseVector &svec = rows_[row]; + MatrixIndexT num_elems = svec.NumElements(); + const std::pair *sdata = svec.Data(); + for (MatrixIndexT e = 0; e < num_elems; e++) + other_col_data[sdata[e].first * other_stride] = sdata[e].second; + } + } +} + +template +void SparseMatrix::CopyToMat(MatrixBase *other, + MatrixTransposeType trans) const; +template +void SparseMatrix::CopyToMat(MatrixBase *other, + MatrixTransposeType trans) const; +template +void SparseMatrix::CopyToMat(MatrixBase *other, + MatrixTransposeType trans) const; +template +void SparseMatrix::CopyToMat(MatrixBase *other, + MatrixTransposeType trans) const; + +template +void SparseMatrix::CopyElementsToVec(VectorBase *other) const { + KALDI_ASSERT(other->Dim() == NumElements()); + 
Real *dst_data = other->Data(); + int32 dst_index = 0; + for (int32 i = 0; i < rows_.size(); ++i) { + for (int32 j = 0; j < rows_[i].NumElements(); ++j) { + dst_data[dst_index] = + static_cast(rows_[i].GetElement(j).second); + dst_index++; + } + } +} + +template +template +void SparseMatrix::CopyFromSmat(const SparseMatrix &other, + MatrixTransposeType trans) { + if (trans == kNoTrans) { + rows_.resize(other.NumRows()); + if (rows_.size() == 0) + return; + for (int32 r = 0; r < rows_.size(); ++r) { + rows_[r].CopyFromSvec(other.Row(r)); + } + } else { + std::vector > > pairs( + other.NumCols()); + for (MatrixIndexT i = 0; i < other.NumRows(); ++i) { + for (int id = 0; id < other.Row(i).NumElements(); ++id) { + MatrixIndexT j = other.Row(i).GetElement(id).first; + Real v = static_cast(other.Row(i).GetElement(id).second); + pairs[j].push_back( { i, v }); + } + } + SparseMatrix temp(other.NumRows(), pairs); + Swap(&temp); + } +} +template +void SparseMatrix::CopyFromSmat(const SparseMatrix &other, + MatrixTransposeType trans); +template +void SparseMatrix::CopyFromSmat(const SparseMatrix &other, + MatrixTransposeType trans); +template +void SparseMatrix::CopyFromSmat(const SparseMatrix &other, + MatrixTransposeType trans); +template +void SparseMatrix::CopyFromSmat(const SparseMatrix &other, + MatrixTransposeType trans); + +template +void SparseMatrix::Write(std::ostream &os, bool binary) const { + if (binary) { + // Note: we can use the same marker for float and double SparseMatrix, + // because internally we use WriteBasicType and ReadBasicType to read the + // floats and doubles, and this will automatically take care of type + // conversion. + WriteToken(os, binary, "SM"); + int32 num_rows = rows_.size(); + WriteBasicType(os, binary, num_rows); + for (int32 row = 0; row < num_rows; row++) + rows_[row].Write(os, binary); + } else { + // The format is "rows=10 dim=20 [ 1 0.4 9 1.2 ] dim=20 [ 3 1.7 19 0.6 ] .. 
+ // not 100% efficient, but easy to work with, and we can re-use the + // read/write code from SparseVector. + int32 num_rows = rows_.size(); + os << "rows=" << num_rows << " "; + for (int32 row = 0; row < num_rows; row++) + rows_[row].Write(os, binary); + os << "\n"; // Might make it a little more readable. + } +} + +template +void SparseMatrix::Read(std::istream &is, bool binary) { + if (binary) { + ExpectToken(is, binary, "SM"); + int32 num_rows; + ReadBasicType(is, binary, &num_rows); + KALDI_ASSERT(num_rows >= 0 && num_rows < 10000000); + rows_.resize(num_rows); + for (int32 row = 0; row < num_rows; row++) + rows_[row].Read(is, binary); + } else { + std::string str; + is >> str; + if (str.substr(0, 5) != "rows=") + KALDI_ERR << "Reading sparse matrix, expected 'rows=xxx', got " << str; + std::string rows_str = str.substr(5, std::string::npos); + std::istringstream rows_istr(rows_str); + int32 num_rows = -1; + rows_istr >> num_rows; + if (num_rows < 0 || rows_istr.fail()) { + KALDI_ERR << "Reading sparse vector, expected 'rows=[int]', got " << str; + } + rows_.resize(num_rows); + for (int32 row = 0; row < num_rows; row++) + rows_[row].Read(is, binary); + } +} + + +template +void SparseMatrix::AddToMat(BaseFloat alpha, + MatrixBase *other, + MatrixTransposeType trans) const { + if (trans == kNoTrans) { + MatrixIndexT num_rows = rows_.size(); + KALDI_ASSERT(other->NumRows() == num_rows); + for (MatrixIndexT i = 0; i < num_rows; i++) { + SubVector vec(*other, i); + rows_[i].AddToVec(alpha, &vec); + } + } else { + Real *other_col_data = other->Data(); + MatrixIndexT other_stride = other->Stride(), + num_rows = NumRows(), num_cols = NumCols(); + KALDI_ASSERT(num_rows == other->NumCols() && num_cols == other->NumRows()); + for (MatrixIndexT row = 0; row < num_rows; row++, other_col_data++) { + const SparseVector &svec = rows_[row]; + MatrixIndexT num_elems = svec.NumElements(); + const std::pair *sdata = svec.Data(); + for (MatrixIndexT e = 0; e < num_elems; e++) + 
other_col_data[sdata[e].first * other_stride] += + alpha * sdata[e].second; + } + } +} + +template +Real VecSvec(const VectorBase &vec, + const SparseVector &svec) { + KALDI_ASSERT(vec.Dim() == svec.Dim()); + MatrixIndexT n = svec.NumElements(); + const std::pair *sdata = svec.Data(); + const Real *data = vec.Data(); + Real ans = 0.0; + for (MatrixIndexT i = 0; i < n; i++) + ans += data[sdata[i].first] * sdata[i].second; + return ans; +} + +template +float VecSvec(const VectorBase &vec, + const SparseVector &svec); +template +double VecSvec(const VectorBase &vec, + const SparseVector &svec); + +template +const SparseVector &SparseMatrix::Row(MatrixIndexT r) const { + KALDI_ASSERT(static_cast(r) < rows_.size()); + return rows_[r]; +} + +template +void SparseMatrix::SetRow(int32 r, const SparseVector &vec) { + KALDI_ASSERT(static_cast(r) < rows_.size() && + vec.Dim() == rows_[0].Dim()); + rows_[r] = vec; +} + + +template +void SparseMatrix::SelectRows(const std::vector &row_indexes, + const SparseMatrix &smat_other) { + Resize(row_indexes.size(), smat_other.NumCols()); + for (int i = 0; i < row_indexes.size(); ++i) { + SetRow(i, smat_other.Row(row_indexes[i])); + } +} + +template +SparseMatrix::SparseMatrix(const std::vector &indexes, int32 dim, + MatrixTransposeType trans) { + const std::vector& idx = indexes; + std::vector > > pair(idx.size()); + for (int i = 0; i < idx.size(); ++i) { + if (idx[i] >= 0) { + pair[i].push_back( { idx[i], Real(1) }); + } + } + SparseMatrix smat_cpu(dim, pair); + if (trans == kNoTrans) { + this->Swap(&smat_cpu); + } else { + SparseMatrix tmp(smat_cpu, kTrans); + this->Swap(&tmp); + } +} + +template +SparseMatrix::SparseMatrix(const std::vector &indexes, + const VectorBase &weights, int32 dim, + MatrixTransposeType trans) { + const std::vector& idx = indexes; + const VectorBase& w = weights; + std::vector > > pair(idx.size()); + for (int i = 0; i < idx.size(); ++i) { + if (idx[i] >= 0) { + pair[i].push_back( { idx[i], w(i) }); + } + } + 
SparseMatrix smat_cpu(dim, pair); + if (trans == kNoTrans) { + this->Swap(&smat_cpu); + } else { + SparseMatrix tmp(smat_cpu, kTrans); + this->Swap(&tmp); + } +} + +template +SparseMatrix& SparseMatrix::operator = ( + const SparseMatrix &other) { + rows_ = other.rows_; + return *this; +} + +template +void SparseMatrix::Swap(SparseMatrix *other) { + rows_.swap(other->rows_); +} + +template +SparseMatrix::SparseMatrix( + MatrixIndexT dim, + const std::vector > > &pairs): + rows_(pairs.size()) { + MatrixIndexT num_rows = pairs.size(); + for (MatrixIndexT row = 0; row < num_rows; row++) { + SparseVector svec(dim, pairs[row]); + rows_[row].Swap(&svec); + } +} + +template +void SparseMatrix::SetRandn(BaseFloat zero_prob) { + MatrixIndexT num_rows = rows_.size(); + for (MatrixIndexT row = 0; row < num_rows; row++) + rows_[row].SetRandn(zero_prob); +} + +template +void SparseMatrix::Resize(MatrixIndexT num_rows, + MatrixIndexT num_cols, + MatrixResizeType resize_type) { + KALDI_ASSERT(num_rows >= 0 && num_cols >= 0); + if (resize_type == kSetZero || resize_type == kUndefined) { + rows_.clear(); + Resize(num_rows, num_cols, kCopyData); + } else { + // Assume resize_type == kCopyData from here. 
+ int32 old_num_rows = rows_.size(), old_num_cols = NumCols(); + SparseVector initializer(num_cols); + rows_.resize(num_rows, initializer); + if (num_cols != old_num_cols) + for (int32 row = 0; row < old_num_rows; row++) + rows_[row].Resize(num_cols, kCopyData); + } +} + +template +void SparseMatrix::AppendSparseMatrixRows( + std::vector > *inputs) { + rows_.clear(); + size_t num_rows = 0; + typename std::vector >::iterator + input_iter = inputs->begin(), + input_end = inputs->end(); + for (; input_iter != input_end; ++input_iter) + num_rows += input_iter->rows_.size(); + rows_.resize(num_rows); + typename std::vector >::iterator + row_iter = rows_.begin(), + row_end = rows_.end(); + for (input_iter = inputs->begin(); input_iter != input_end; ++input_iter) { + typename std::vector >::iterator + input_row_iter = input_iter->rows_.begin(), + input_row_end = input_iter->rows_.end(); + for (; input_row_iter != input_row_end; ++input_row_iter, ++row_iter) + row_iter->Swap(&(*input_row_iter)); + } + KALDI_ASSERT(row_iter == row_end); + int32 num_cols = NumCols(); + for (row_iter = rows_.begin(); row_iter != row_end; ++row_iter) { + if (row_iter->Dim() != num_cols) + KALDI_ERR << "Appending rows with inconsistent dimensions, " + << row_iter->Dim() << " vs. 
" << num_cols; + } + inputs->clear(); +} + +template +void SparseMatrix::Scale(Real alpha) { + MatrixIndexT num_rows = rows_.size(); + for (MatrixIndexT row = 0; row < num_rows; row++) + rows_[row].Scale(alpha); +} + +template +SparseMatrix::SparseMatrix(const MatrixBase &mat) { + MatrixIndexT num_rows = mat.NumRows(); + rows_.resize(num_rows); + for (int32 row = 0; row < num_rows; row++) { + SparseVector this_row(mat.Row(row)); + rows_[row].Swap(&this_row); + } +} + +template +Real TraceMatSmat(const MatrixBase &A, + const SparseMatrix &B, + MatrixTransposeType trans) { + Real sum = 0.0; + if (trans == kTrans) { + MatrixIndexT num_rows = A.NumRows(); + KALDI_ASSERT(B.NumRows() == num_rows); + for (MatrixIndexT r = 0; r < num_rows; r++) + sum += VecSvec(A.Row(r), B.Row(r)); + } else { + const Real *A_col_data = A.Data(); + MatrixIndexT Astride = A.Stride(), Acols = A.NumCols(), Arows = A.NumRows(); + KALDI_ASSERT(Arows == B.NumCols() && Acols == B.NumRows()); + sum = 0.0; + for (MatrixIndexT i = 0; i < Acols; i++, A_col_data++) { + Real col_sum = 0.0; + const SparseVector &svec = B.Row(i); + MatrixIndexT num_elems = svec.NumElements(); + const std::pair *sdata = svec.Data(); + for (MatrixIndexT e = 0; e < num_elems; e++) + col_sum += A_col_data[Astride * sdata[e].first] * sdata[e].second; + sum += col_sum; + } + } + return sum; +} + +template +float TraceMatSmat(const MatrixBase &A, + const SparseMatrix &B, + MatrixTransposeType trans); +template +double TraceMatSmat(const MatrixBase &A, + const SparseMatrix &B, + MatrixTransposeType trans); + +void GeneralMatrix::Clear() { + mat_.Resize(0, 0); + cmat_.Clear(); + smat_.Resize(0, 0); +} + +GeneralMatrix& GeneralMatrix::operator= (const MatrixBase &mat) { + Clear(); + mat_ = mat; + return *this; +} + +GeneralMatrix& GeneralMatrix::operator= (const CompressedMatrix &cmat) { + Clear(); + cmat_ = cmat; + return *this; +} + +GeneralMatrix& GeneralMatrix::operator= (const SparseMatrix &smat) { + Clear(); + smat_ = smat; + 
return *this; +} + +GeneralMatrix& GeneralMatrix::operator= (const GeneralMatrix &gmat) { + mat_ = gmat.mat_; + smat_ = gmat.smat_; + cmat_ = gmat.cmat_; + return *this; +} + + +GeneralMatrixType GeneralMatrix::Type() const { + if (smat_.NumRows() != 0) + return kSparseMatrix; + else if (cmat_.NumRows() != 0) + return kCompressedMatrix; + else + return kFullMatrix; +} + +MatrixIndexT GeneralMatrix::NumRows() const { + MatrixIndexT r = smat_.NumRows(); + if (r != 0) + return r; + r = cmat_.NumRows(); + if (r != 0) + return r; + return mat_.NumRows(); +} + +MatrixIndexT GeneralMatrix::NumCols() const { + MatrixIndexT r = smat_.NumCols(); + if (r != 0) + return r; + r = cmat_.NumCols(); + if (r != 0) + return r; + return mat_.NumCols(); +} + + +void GeneralMatrix::Compress() { + if (mat_.NumRows() != 0) { + cmat_.CopyFromMat(mat_); + mat_.Resize(0, 0); + } +} + +void GeneralMatrix::Uncompress() { + if (cmat_.NumRows() != 0) { + mat_.Resize(cmat_.NumRows(), cmat_.NumCols(), kUndefined); + cmat_.CopyToMat(&mat_); + cmat_.Clear(); + } +} + +void GeneralMatrix::GetMatrix(Matrix *mat) const { + if (mat_.NumRows() !=0) { + *mat = mat_; + } else if (cmat_.NumRows() != 0) { + mat->Resize(cmat_.NumRows(), cmat_.NumCols(), kUndefined); + cmat_.CopyToMat(mat); + } else if (smat_.NumRows() != 0) { + mat->Resize(smat_.NumRows(), smat_.NumCols(), kUndefined); + smat_.CopyToMat(mat); + } else { + mat->Resize(0, 0); + } +} + +void GeneralMatrix::CopyToMat(MatrixBase *mat, + MatrixTransposeType trans) const { + if (mat_.NumRows() !=0) { + mat->CopyFromMat(mat_, trans); + } else if (cmat_.NumRows() != 0) { + cmat_.CopyToMat(mat, trans); + } else if (smat_.NumRows() != 0) { + smat_.CopyToMat(mat, trans); + } else { + KALDI_ASSERT(mat->NumRows() == 0); + } +} + +void GeneralMatrix::Scale(BaseFloat alpha) { + if (mat_.NumRows() != 0) { + mat_.Scale(alpha); + } else if (cmat_.NumRows() != 0) { + cmat_.Scale(alpha); + } else if (smat_.NumRows() != 0) { + smat_.Scale(alpha); + } + +} +const 
SparseMatrix& GeneralMatrix::GetSparseMatrix() const { + if (mat_.NumRows() != 0 || cmat_.NumRows() != 0) + KALDI_ERR << "GetSparseMatrix called on GeneralMatrix of wrong type."; + return smat_; +} + +void GeneralMatrix::SwapSparseMatrix(SparseMatrix *smat) { + if (mat_.NumRows() != 0 || cmat_.NumRows() != 0) + KALDI_ERR << "GetSparseMatrix called on GeneralMatrix of wrong type."; + smat->Swap(&smat_); +} + +void GeneralMatrix::SwapCompressedMatrix(CompressedMatrix *cmat) { + if (mat_.NumRows() != 0 || smat_.NumRows() != 0) + KALDI_ERR << "GetSparseMatrix called on GeneralMatrix of wrong type."; + cmat->Swap(&cmat_); +} + +const CompressedMatrix &GeneralMatrix::GetCompressedMatrix() const { + if (mat_.NumRows() != 0 || smat_.NumRows() != 0) + KALDI_ERR << "GetCompressedMatrix called on GeneralMatrix of wrong type."; + return cmat_; +} + +const Matrix &GeneralMatrix::GetFullMatrix() const { + if (smat_.NumRows() != 0 || cmat_.NumRows() != 0) + KALDI_ERR << "GetFullMatrix called on GeneralMatrix of wrong type."; + return mat_; +} + + +void GeneralMatrix::SwapFullMatrix(Matrix *mat) { + if (cmat_.NumRows() != 0 || smat_.NumRows() != 0) + KALDI_ERR << "SwapMatrix called on GeneralMatrix of wrong type."; + mat->Swap(&mat_); +} + +void GeneralMatrix::Write(std::ostream &os, bool binary) const { + if (smat_.NumRows() != 0) { + smat_.Write(os, binary); + } else if (cmat_.NumRows() != 0) { + cmat_.Write(os, binary); + } else { + mat_.Write(os, binary); + } +} + +void GeneralMatrix::Read(std::istream &is, bool binary) { + Clear(); + if (binary) { + int peekval = is.peek(); + if (peekval == 'C') { + // Token CM for compressed matrix + cmat_.Read(is, binary); + } else if (peekval == 'S') { + // Token SM for sparse matrix + smat_.Read(is, binary); + } else { + mat_.Read(is, binary); + } + } else { + // note: in text mode we will only ever read regular + // or sparse matrices, because the compressed-matrix format just + // gets written as a regular matrix in text mode. 
+ is >> std::ws; // Eat up white space. + int peekval = is.peek(); + if (peekval == 'r') { // sparse format starts rows=[int]. + smat_.Read(is, binary); + } else { + mat_.Read(is, binary); + } + } +} + + +void AppendGeneralMatrixRows(const std::vector &src, + GeneralMatrix *mat) { + mat->Clear(); + int32 size = src.size(); + if (size == 0) + return; + bool all_sparse = true; + for (int32 i = 0; i < size; i++) { + if (src[i]->Type() != kSparseMatrix && src[i]->NumRows() != 0) { + all_sparse = false; + break; + } + } + if (all_sparse) { + std::vector > sparse_mats(size); + for (int32 i = 0; i < size; i++) + sparse_mats[i] = src[i]->GetSparseMatrix(); + SparseMatrix appended_mat; + appended_mat.AppendSparseMatrixRows(&sparse_mats); + mat->SwapSparseMatrix(&appended_mat); + } else { + int32 tot_rows = 0, num_cols = -1; + for (int32 i = 0; i < size; i++) { + const GeneralMatrix &src_mat = *(src[i]); + int32 src_rows = src_mat.NumRows(), src_cols = src_mat.NumCols(); + if (src_rows != 0) { + tot_rows += src_rows; + if (num_cols == -1) num_cols = src_cols; + else if (num_cols != src_cols) + KALDI_ERR << "Appending rows of matrices with inconsistent num-cols: " + << num_cols << " vs. 
" << src_cols; + } + } + Matrix appended_mat(tot_rows, num_cols, kUndefined); + int32 row_offset = 0; + for (int32 i = 0; i < size; i++) { + const GeneralMatrix &src_mat = *(src[i]); + int32 src_rows = src_mat.NumRows(); + if (src_rows != 0) { + SubMatrix dest_submat(appended_mat, row_offset, src_rows, + 0, num_cols); + src_mat.CopyToMat(&dest_submat); + row_offset += src_rows; + } + } + KALDI_ASSERT(row_offset == tot_rows); + mat->SwapFullMatrix(&appended_mat); + } +} + +void FilterCompressedMatrixRows(const CompressedMatrix &in, + const std::vector &keep_rows, + Matrix *out) { + KALDI_ASSERT(keep_rows.size() == static_cast(in.NumRows())); + int32 num_kept_rows = 0; + std::vector::const_iterator iter = keep_rows.begin(), + end = keep_rows.end(); + for (; iter != end; ++iter) + if (*iter) + num_kept_rows++; + if (num_kept_rows == 0) + KALDI_ERR << "No kept rows"; + if (num_kept_rows == static_cast(keep_rows.size())) { + out->Resize(in.NumRows(), in.NumCols(), kUndefined); + in.CopyToMat(out); + return; + } + const BaseFloat heuristic = 0.33; + // should be > 0 and < 1.0. represents the performance hit we get from + // iterating row-wise versus column-wise in compressed-matrix uncompression. + + if (num_kept_rows > heuristic * in.NumRows()) { + // if quite a few of the the rows are kept, it may be more efficient + // to uncompress the entire compressed matrix, since per-column operation + // is more efficient. 
+ Matrix full_mat(in); + FilterMatrixRows(full_mat, keep_rows, out); + } else { + out->Resize(num_kept_rows, in.NumCols(), kUndefined); + + iter = keep_rows.begin(); + int32 out_row = 0; + for (int32 in_row = 0; iter != end; ++iter, ++in_row) { + if (*iter) { + SubVector dest(*out, out_row); + in.CopyRowToVec(in_row, &dest); + out_row++; + } + } + KALDI_ASSERT(out_row == num_kept_rows); + } +} + +template +void FilterMatrixRows(const Matrix &in, + const std::vector &keep_rows, + Matrix *out) { + KALDI_ASSERT(keep_rows.size() == static_cast(in.NumRows())); + int32 num_kept_rows = 0; + std::vector::const_iterator iter = keep_rows.begin(), + end = keep_rows.end(); + for (; iter != end; ++iter) + if (*iter) + num_kept_rows++; + if (num_kept_rows == 0) + KALDI_ERR << "No kept rows"; + if (num_kept_rows == static_cast(keep_rows.size())) { + *out = in; + return; + } + out->Resize(num_kept_rows, in.NumCols(), kUndefined); + iter = keep_rows.begin(); + int32 out_row = 0; + for (int32 in_row = 0; iter != end; ++iter, ++in_row) { + if (*iter) { + SubVector src(in, in_row); + SubVector dest(*out, out_row); + dest.CopyFromVec(src); + out_row++; + } + } + KALDI_ASSERT(out_row == num_kept_rows); +} + +template +void FilterMatrixRows(const Matrix &in, + const std::vector &keep_rows, + Matrix *out); +template +void FilterMatrixRows(const Matrix &in, + const std::vector &keep_rows, + Matrix *out); + +template +void FilterSparseMatrixRows(const SparseMatrix &in, + const std::vector &keep_rows, + SparseMatrix *out) { + KALDI_ASSERT(keep_rows.size() == static_cast(in.NumRows())); + int32 num_kept_rows = 0; + std::vector::const_iterator iter = keep_rows.begin(), + end = keep_rows.end(); + for (; iter != end; ++iter) + if (*iter) + num_kept_rows++; + if (num_kept_rows == 0) + KALDI_ERR << "No kept rows"; + if (num_kept_rows == static_cast(keep_rows.size())) { + *out = in; + return; + } + out->Resize(num_kept_rows, in.NumCols(), kUndefined); + iter = keep_rows.begin(); + int32 out_row = 
0; + for (int32 in_row = 0; iter != end; ++iter, ++in_row) { + if (*iter) { + out->SetRow(out_row, in.Row(in_row)); + out_row++; + } + } + KALDI_ASSERT(out_row == num_kept_rows); +} + +template +void FilterSparseMatrixRows(const SparseMatrix &in, + const std::vector &keep_rows, + SparseMatrix *out); +template +void FilterSparseMatrixRows(const SparseMatrix &in, + const std::vector &keep_rows, + SparseMatrix *out); + + +void FilterGeneralMatrixRows(const GeneralMatrix &in, + const std::vector &keep_rows, + GeneralMatrix *out) { + out->Clear(); + KALDI_ASSERT(keep_rows.size() == static_cast(in.NumRows())); + int32 num_kept_rows = 0; + std::vector::const_iterator iter = keep_rows.begin(), + end = keep_rows.end(); + for (; iter != end; ++iter) + if (*iter) + num_kept_rows++; + if (num_kept_rows == 0) + KALDI_ERR << "No kept rows"; + if (num_kept_rows == static_cast(keep_rows.size())) { + *out = in; + return; + } + switch (in.Type()) { + case kCompressedMatrix: { + const CompressedMatrix &cmat = in.GetCompressedMatrix(); + Matrix full_mat; + FilterCompressedMatrixRows(cmat, keep_rows, &full_mat); + out->SwapFullMatrix(&full_mat); + return; + } + case kSparseMatrix: { + const SparseMatrix &smat = in.GetSparseMatrix(); + SparseMatrix smat_out; + FilterSparseMatrixRows(smat, keep_rows, &smat_out); + out->SwapSparseMatrix(&smat_out); + return; + } + case kFullMatrix: { + const Matrix &full_mat = in.GetFullMatrix(); + Matrix full_mat_out; + FilterMatrixRows(full_mat, keep_rows, &full_mat_out); + out->SwapFullMatrix(&full_mat_out); + return; + } + default: + KALDI_ERR << "Invalid general-matrix type."; + } +} + +void GeneralMatrix::AddToMat(BaseFloat alpha, MatrixBase *mat, + MatrixTransposeType trans) const { + switch (this->Type()) { + case kFullMatrix: { + mat->AddMat(alpha, mat_, trans); + break; + } + case kSparseMatrix: { + smat_.AddToMat(alpha, mat, trans); + break; + } + case kCompressedMatrix: { + Matrix temp_mat(cmat_); + mat->AddMat(alpha, temp_mat, trans); + 
break; + } + default: + KALDI_ERR << "Invalid general-matrix type."; + } +} + +template +Real SparseVector::Max(int32 *index_out) const { + KALDI_ASSERT(dim_ > 0 && pairs_.size() <= static_cast(dim_)); + Real ans = -std::numeric_limits::infinity(); + int32 index = 0; + typename std::vector >::const_iterator + iter = pairs_.begin(), end = pairs_.end(); + for (; iter != end; ++iter) { + if (iter->second > ans) { + ans = iter->second; + index = iter->first; + } + } + if (ans >= 0 || pairs_.size() == dim_) { + // ans >= 0 will be the normal case. + // if pairs_.size() == dim_ then we need to return + // even a negative answer as there are no spaces (hence no unlisted zeros). + *index_out = index; + return ans; + } + // all the stored elements are < 0, but there are unlisted + // elements -> pick the first unlisted element. + // Note that this class requires that the indexes are sorted + // and unique. + index = 0; // "index" will always be the next index, that + // we haven't seen listed yet. + iter = pairs_.begin(); + for (; iter != end; ++iter) { + if (iter->first > index) { // index "index" is not listed. + *index_out = index; + return 0.0; + } else { + // index is the next potential gap in the indexes. + index = iter->first + 1; + } + } + // we can reach here if either pairs_.empty(), or + // pairs_ is nonempty but contains a sequence (0, 1, 2,...). 
+ if (!pairs_.empty()) + index = pairs_.back().first + 1; + // else leave index at zero + KALDI_ASSERT(index < dim_); + *index_out = index; + return 0.0; +} + +template +SparseVector::SparseVector(const VectorBase &vec) { + MatrixIndexT dim = vec.Dim(); + dim_ = dim; + if (dim == 0) + return; + const Real *ptr = vec.Data(); + for (MatrixIndexT i = 0; i < dim; i++) { + Real val = ptr[i]; + if (val != 0.0) + pairs_.push_back(std::pair(i,val)); + } +} + +void GeneralMatrix::Swap(GeneralMatrix *other) { + mat_.Swap(&(other->mat_)); + cmat_.Swap(&(other->cmat_)); + smat_.Swap(&(other->smat_)); +} + + +void ExtractRowRangeWithPadding( + const GeneralMatrix &in, + int32 row_offset, + int32 num_rows, + GeneralMatrix *out) { + // make sure 'out' is empty to start with. + Matrix empty_mat; + *out = empty_mat; + if (num_rows == 0) return; + switch (in.Type()) { + case kFullMatrix: { + const Matrix &mat_in = in.GetFullMatrix(); + int32 num_rows_in = mat_in.NumRows(), num_cols = mat_in.NumCols(); + KALDI_ASSERT(num_rows_in > 0); // we can't extract >0 rows from an empty + // matrix. + Matrix mat_out(num_rows, num_cols, kUndefined); + for (int32 row = 0; row < num_rows; row++) { + int32 row_in = row + row_offset; + if (row_in < 0) row_in = 0; + else if (row_in >= num_rows_in) row_in = num_rows_in - 1; + SubVector vec_in(mat_in, row_in), + vec_out(mat_out, row); + vec_out.CopyFromVec(vec_in); + } + out->SwapFullMatrix(&mat_out); + break; + } + case kSparseMatrix: { + const SparseMatrix &smat_in = in.GetSparseMatrix(); + int32 num_rows_in = smat_in.NumRows(), + num_cols = smat_in.NumCols(); + KALDI_ASSERT(num_rows_in > 0); // we can't extract >0 rows from an empty + // matrix. 
+ SparseMatrix smat_out(num_rows, num_cols); + for (int32 row = 0; row < num_rows; row++) { + int32 row_in = row + row_offset; + if (row_in < 0) row_in = 0; + else if (row_in >= num_rows_in) row_in = num_rows_in - 1; + smat_out.SetRow(row, smat_in.Row(row_in)); + } + out->SwapSparseMatrix(&smat_out); + break; + } + case kCompressedMatrix: { + const CompressedMatrix &cmat_in = in.GetCompressedMatrix(); + bool allow_padding = true; + CompressedMatrix cmat_out(cmat_in, row_offset, num_rows, + 0, cmat_in.NumCols(), allow_padding); + out->SwapCompressedMatrix(&cmat_out); + break; + } + default: + KALDI_ERR << "Bad matrix type."; + } +} + + + +template class SparseVector; +template class SparseVector; +template class SparseMatrix; +template class SparseMatrix; + +} // namespace kaldi diff --git a/speechx/speechx/kaldi/matrix/sparse-matrix.h b/speechx/speechx/kaldi/matrix/sparse-matrix.h new file mode 100644 index 00000000..76f77f53 --- /dev/null +++ b/speechx/speechx/kaldi/matrix/sparse-matrix.h @@ -0,0 +1,452 @@ +// matrix/sparse-matrix.h + +// Copyright 2015 Johns Hopkins University (author: Daniel Povey) +// 2015 Guoguo Chen +// 2017 Shiyin Kang + + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef KALDI_MATRIX_SPARSE_MATRIX_H_ +#define KALDI_MATRIX_SPARSE_MATRIX_H_ 1 + +#include +#include + +#include "matrix/matrix-common.h" +#include "matrix/kaldi-matrix.h" +#include "matrix/kaldi-vector.h" +#include "matrix/compressed-matrix.h" + +namespace kaldi { + + +/// \addtogroup matrix_group +/// @{ + +template +class SparseVector { + public: + MatrixIndexT Dim() const { return dim_; } + + Real Sum() const; + + template + void CopyElementsToVec(VectorBase *vec) const; + + // *vec += alpha * *this. + template + void AddToVec(Real alpha, + VectorBase *vec) const; + + template + void CopyFromSvec(const SparseVector &other); + + SparseVector &operator = (const SparseVector &other); + + SparseVector(const SparseVector &other) { *this = other; } + + void Swap(SparseVector *other); + + // Returns the maximum value in this row and outputs the index associated with + // it. This is not the index into the Data() pointer, it is the index into + // the vector it represents, i.e. the .first value in the pair. + // If this vector's Dim() is zero it is an error to call this function. + // If all the elements stored were negative and there underlying vector had + // zero indexes not listed in the elements, or if no elements are stored, it + // will return the first un-listed index, whose value (implicitly) is zero. + Real Max(int32 *index) const; + + /// Returns the number of nonzero elements. + MatrixIndexT NumElements() const { return pairs_.size(); } + + /// get an indexed element (0 <= i < NumElements()). + const std::pair &GetElement(MatrixIndexT i) const { + return pairs_[i]; + } + + // returns pointer to element data, or NULL if empty (use with NumElements()). + std::pair *Data(); + + // returns pointer to element data, or NULL if empty (use with NumElements()); + // const version + const std::pair *Data() const; + + /// Sets elements to zero with probability zero_prob, else normally + /// distributed. Useful in testing. 
+ void SetRandn(BaseFloat zero_prob); + + SparseVector(): dim_(0) { } + + explicit SparseVector(MatrixIndexT dim): dim_(dim) { KALDI_ASSERT(dim >= 0); } + + // constructor from pairs; does not assume input pairs are sorted and uniq + SparseVector(MatrixIndexT dim, + const std::vector > &pairs); + + // constructor from a VectorBase that keeps only the nonzero elements of 'vec'. + explicit SparseVector(const VectorBase &vec); + + /// Resizes to this dimension. resize_type == kUndefined + /// behaves the same as kSetZero. + void Resize(MatrixIndexT dim, MatrixResizeType resize_type = kSetZero); + + void Write(std::ostream &os, bool binary) const; + + void Read(std::istream &os, bool binary); + + /// Scale all elements of sparse vector. + void Scale(Real alpha); + + private: + MatrixIndexT dim_; + // pairs of (row-index, value). Stored in sorted order with no duplicates. + // For now we use std::vector, but we could change this. + std::vector > pairs_; +}; + + +template +Real VecSvec(const VectorBase &vec, + const SparseVector &svec); + + + +template +class SparseMatrix { + public: + MatrixIndexT NumRows() const; + + MatrixIndexT NumCols() const; + + MatrixIndexT NumElements() const; + + Real Sum() const; + + Real FrobeniusNorm() const; + + + /// This constructor creates a SparseMatrix that just contains the nonzero + /// elements of 'mat'. + explicit SparseMatrix(const MatrixBase &mat); + + /// Copy to matrix. It must already have the correct size. + template + void CopyToMat(MatrixBase *other, + MatrixTransposeType t = kNoTrans) const; + + /// Copies the values of all the elements in SparseMatrix into a VectorBase + /// object. + void CopyElementsToVec(VectorBase *other) const; + + /// Copies data from another sparse matrix. + template + void CopyFromSmat(const SparseMatrix &other, + MatrixTransposeType trans = kNoTrans); + + /// Does *other = *other + alpha * *this. 
+ void AddToMat(BaseFloat alpha, MatrixBase *other, + MatrixTransposeType t = kNoTrans) const; + + SparseMatrix &operator = (const SparseMatrix &other); + + SparseMatrix(const SparseMatrix &other, MatrixTransposeType trans = + kNoTrans) { + this->CopyFromSmat(other, trans); + } + + void Swap(SparseMatrix *other); + + // returns pointer to element data, or NULL if empty (use with NumElements()). + SparseVector *Data(); + + // returns pointer to element data, or NULL if empty (use with NumElements()); + // const version + const SparseVector *Data() const; + + // initializer from the type that elsewhere in Kaldi is referred to as type + // Posterior. indexed first by row-index; the pairs are (column-index, value), + // and the constructor does not require them to be sorted and uniq. + SparseMatrix( + int32 dim, + const std::vector > > &pairs); + + /// Sets up to a pseudo-randomly initialized matrix, with each element zero + /// with probability zero_prob and else normally distributed- mostly for + /// purposes of testing. + void SetRandn(BaseFloat zero_prob); + + void Write(std::ostream &os, bool binary) const; + + void Read(std::istream &os, bool binary); + + const SparseVector &Row(MatrixIndexT r) const; + + /// Sets row r to "vec"; makes sure it has the correct dimension. + void SetRow(int32 r, const SparseVector &vec); + + /// Select a subset of the rows of a SparseMatrix. + /// Sets *this to only the rows of 'smat_other' that are listed + /// in 'row_indexes'. + /// 'row_indexes' must satisfy 0 <= row_indexes[i] < smat_other.NumRows(). + void SelectRows(const std::vector &row_indexes, + const SparseMatrix &smat_other); + + + /// Sets *this to all the rows of *inputs appended together; this + /// function is destructive of the inputs. Requires, obviously, + /// that the inputs all have the same dimension (although some may be + /// empty). 
+ void AppendSparseMatrixRows(std::vector > *inputs); + + SparseMatrix() { } + + SparseMatrix(int32 num_rows, int32 num_cols) { Resize(num_rows, num_cols); } + + /// Constructor from an array of indexes. + /// If trans == kNoTrans, construct a sparse matrix + /// with num-rows == indexes.Dim() and num-cols = 'dim'. + /// 'indexes' is expected to contain elements in the + /// range [0, dim - 1]. Each row 'i' of *this after + /// calling the constructor will contain a single + /// element at column-index indexes[i] with value 1.0. + /// + /// If trans == kTrans, the result will be the transpose + /// of the sparse matrix described above. + SparseMatrix(const std::vector &indexes, int32 dim, + MatrixTransposeType trans = kNoTrans); + + /// Constructor from an array of indexes and an array of + /// weights; requires indexes.Dim() == weights.Dim(). + /// If trans == kNoTrans, construct a sparse matrix + /// with num-rows == indexes.Dim() and num-cols = 'dim'. + /// 'indexes' is expected to contain elements in the + /// range [0, dim - 1]. Each row 'i' of *this after + /// calling the constructor will contain a single + /// element at column-index indexes[i] with value weights[i]. + /// If trans == kTrans, the result will be the transpose + /// of the sparse matrix described above. + SparseMatrix(const std::vector &indexes, + const VectorBase &weights, int32 dim, + MatrixTransposeType trans = kNoTrans); + + /// Resizes the matrix; analogous to Matrix::Resize(). resize_type == + /// kUndefined behaves the same as kSetZero. + void Resize(MatrixIndexT rows, MatrixIndexT cols, + MatrixResizeType resize_type = kSetZero); + + /// Scale all elements in sparse matrix. + void Scale(Real alpha); + + // Use the Matrix::CopyFromSmat() function to copy from this to Matrix. Also + // see Matrix::AddSmat(). There is not very extensive functionality for + // SparseMat just yet (e.g. no matrix multiply); we will add things as needed + // and as it seems necessary. 
+ private: + // vector of SparseVectors, all of same dime (use an stl vector for now; this + // could change). + std::vector > rows_; +}; + + +template +Real TraceMatSmat(const MatrixBase &A, + const SparseMatrix &B, + MatrixTransposeType trans = kNoTrans); + + +enum GeneralMatrixType { + kFullMatrix, + kCompressedMatrix, + kSparseMatrix +}; + +/// This class is a wrapper that enables you to store a matrix +/// in one of three forms: either as a Matrix, or a CompressedMatrix, +/// or a SparseMatrix. It handles the I/O for you, i.e. you read +/// and write a single object type. It is useful for neural-net training +/// targets which might be sparse or not, and might be compressed or not. +class GeneralMatrix { + public: + /// Returns the type of the matrix: kSparseMatrix, kCompressedMatrix or + /// kFullMatrix. If this matrix is empty, returns kFullMatrix. + GeneralMatrixType Type() const; + + void Compress(); // If it was a full matrix, compresses, changing Type() to + // kCompressedMatrix; otherwise does nothing. + + void Uncompress(); // If it was a compressed matrix, uncompresses, changing + // Type() to kFullMatrix; otherwise does nothing. + + void Write(std::ostream &os, bool binary) const; + + + /// Note: if you write a compressed matrix in text form, it will be read as + /// a regular full matrix. + void Read(std::istream &is, bool binary); + + /// Returns the contents as a SparseMatrix. This will only work if + /// Type() returns kSparseMatrix, or NumRows() == 0; otherwise it will crash. + const SparseMatrix &GetSparseMatrix() const; + + /// Swaps the with the given SparseMatrix. This will only work if + /// Type() returns kSparseMatrix, or NumRows() == 0. + void SwapSparseMatrix(SparseMatrix *smat); + + /// Returns the contents as a compressed matrix. This will only work if + /// Type() returns kCompressedMatrix, or NumRows() == 0; otherwise it will + /// crash. 
+ const CompressedMatrix &GetCompressedMatrix() const; + + /// Swaps the with the given CompressedMatrix. This will only work if + /// Type() returns kCompressedMatrix, or NumRows() == 0. + void SwapCompressedMatrix(CompressedMatrix *cmat); + + /// Returns the contents as a Matrix. This will only work if + /// Type() returns kFullMatrix, or NumRows() == 0; otherwise it will crash. + const Matrix& GetFullMatrix() const; + + /// Outputs the contents as a matrix. This will work regardless of + /// Type(). Sizes its output, unlike CopyToMat(). + void GetMatrix(Matrix *mat) const; + + /// Swaps the with the given Matrix. This will only work if + /// Type() returns kFullMatrix, or NumRows() == 0. + void SwapFullMatrix(Matrix *mat); + + /// Copies contents, regardless of type, to "mat", which must be correctly + /// sized. See also GetMatrix(), which will size its output for you. + void CopyToMat(MatrixBase *mat, + MatrixTransposeType trans = kNoTrans) const; + + /// Copies contents, regardless of type, to "cu_mat", which must be + /// correctly sized. Implemented in ../cudamatrix/cu-sparse-matrix.cc + void CopyToMat(CuMatrixBase *cu_mat, + MatrixTransposeType trans = kNoTrans) const; + + /// Adds alpha times *this to mat. + void AddToMat(BaseFloat alpha, MatrixBase *mat, + MatrixTransposeType trans = kNoTrans) const; + + /// Adds alpha times *this to cu_mat. + /// Implemented in ../cudamatrix/cu-sparse-matrix.cc + void AddToMat(BaseFloat alpha, CuMatrixBase *cu_mat, + MatrixTransposeType trans = kNoTrans) const; + + /// Scale each element of matrix by alpha. + void Scale(BaseFloat alpha); + + /// Assignment from regular matrix. + GeneralMatrix &operator= (const MatrixBase &mat); + + /// Assignment from compressed matrix. 
+ GeneralMatrix &operator= (const CompressedMatrix &mat); + + /// Assignment from SparseMatrix + GeneralMatrix &operator= (const SparseMatrix &smat); + + MatrixIndexT NumRows() const; + + MatrixIndexT NumCols() const; + + explicit GeneralMatrix(const MatrixBase &mat) { *this = mat; } + + explicit GeneralMatrix(const CompressedMatrix &cmat) { *this = cmat; } + + explicit GeneralMatrix(const SparseMatrix &smat) { *this = smat; } + + GeneralMatrix() { } + // Assignment operator. + GeneralMatrix &operator =(const GeneralMatrix &other); + // Copy constructor + GeneralMatrix(const GeneralMatrix &other) { *this = other; } + // Sets to the empty matrix. + void Clear(); + // shallow swap + void Swap(GeneralMatrix *other); + private: + // We don't explicitly store the type of the matrix. Rather, we make + // sure that only one of the matrices is ever nonempty, and the Type() + // returns that one, or kFullMatrix if all are empty. + Matrix mat_; + CompressedMatrix cmat_; + SparseMatrix smat_; +}; + + +/// Appends all the matrix rows of a list of GeneralMatrixes, to get a single +/// GeneralMatrix. Preserves sparsity if all inputs were sparse (or empty). +/// Does not preserve compression, if inputs were compressed; you have to +/// re-compress manually, if that's what you need. +void AppendGeneralMatrixRows(const std::vector &src, + GeneralMatrix *mat); + + +/// Outputs a SparseMatrix containing only the rows r of "in" such that +/// keep_rows[r] == true. keep_rows.size() must equal in.NumRows(), and rows +/// must contain at least one "true" element. +template +void FilterSparseMatrixRows(const SparseMatrix &in, + const std::vector &keep_rows, + SparseMatrix *out); + +/// Outputs a Matrix containing only the rows r of "in" such that +/// keep_keep_rows[r] == true. keep_rows.size() must equal in.NumRows(), and +/// keep_rows must contain at least one "true" element. 
+template +void FilterMatrixRows(const Matrix &in, + const std::vector &keep_rows, + Matrix *out); + +/// Outputs a Matrix containing only the rows r of "in" such that +/// keep_rows[r] == true. keep_rows.size() must equal in.NumRows(), and rows +/// must contain at least one "true" element. +void FilterCompressedMatrixRows(const CompressedMatrix &in, + const std::vector &keep_rows, + Matrix *out); + + +/// Outputs a GeneralMatrix containing only the rows r of "in" such that +/// keep_rows[r] == true. keep_rows.size() must equal in.NumRows(), and +/// keep_rows must contain at least one "true" element. If in.Type() is +/// kCompressedMatrix, the result will not be compressed; otherwise, the type +/// is preserved. +void FilterGeneralMatrixRows(const GeneralMatrix &in, + const std::vector &keep_rows, + GeneralMatrix *out); + +/// This function extracts a row-range of a GeneralMatrix and writes +/// as a GeneralMatrix containing the same type of underlying +/// matrix. If the row-range is partly outside the row-range of 'in' +/// (i.e. if row_offset < 0 or row_offset + num_rows > in.NumRows()) +/// then it will pad with copies of the first and last row as +/// needed. +/// This is more efficient than un-compressing and +/// re-compressing the underlying CompressedMatrix, and causes +/// less accuracy loss due to re-compression (no loss in most cases). +void ExtractRowRangeWithPadding( + const GeneralMatrix &in, + int32 row_offset, + int32 num_rows, + GeneralMatrix *out); + + +/// @} end of \addtogroup matrix_group + + +} // namespace kaldi + +#endif // KALDI_MATRIX_SPARSE_MATRIX_H_ diff --git a/speechx/speechx/kaldi/matrix/srfft.cc b/speechx/speechx/kaldi/matrix/srfft.cc new file mode 100644 index 00000000..f6189496 --- /dev/null +++ b/speechx/speechx/kaldi/matrix/srfft.cc @@ -0,0 +1,440 @@ +// matrix/srfft.cc + +// Copyright 2009-2011 Microsoft Corporation; Go Vivace Inc. 
+ +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +// + +// This file includes a modified version of code originally published in Malvar, +// H., "Signal processing with lapped transforms, " Artech House, Inc., 1992. The +// current copyright holder of the original code, Henrique S. Malvar, has given +// his permission for the release of this modified version under the Apache +// License v2.0. + + +#include "matrix/srfft.h" +#include "matrix/matrix-functions.h" + +namespace kaldi { + + +template +SplitRadixComplexFft::SplitRadixComplexFft(MatrixIndexT N) { + if ( (N & (N-1)) != 0 || N <= 1) + KALDI_ERR << "SplitRadixComplexFft called with invalid number of points " + << N; + N_ = N; + logn_ = 0; + while (N > 1) { + N >>= 1; + logn_ ++; + } + ComputeTables(); +} + +template +SplitRadixComplexFft::SplitRadixComplexFft( + const SplitRadixComplexFft &other): + N_(other.N_), logn_(other.logn_) { + // This code duplicates tables from a previously computed object. + // Compare with the code in ComputeTables(). 
+ MatrixIndexT lg2 = logn_ >> 1; + if (logn_ & 1) lg2++; + MatrixIndexT brseed_size = 1 << lg2; + brseed_ = new MatrixIndexT[brseed_size]; + std::memcpy(brseed_, other.brseed_, sizeof(MatrixIndexT) * brseed_size); + + if (logn_ < 4) { + tab_ = NULL; + } else { + tab_ = new Real*[logn_ - 3]; + for (MatrixIndexT i = logn_; i >= 4 ; i--) { + MatrixIndexT m = 1 << i, m2 = m / 2, m4 = m2 / 2; + MatrixIndexT this_array_size = 6 * (m4 - 2); + tab_[i-4] = new Real[this_array_size]; + std::memcpy(tab_[i-4], other.tab_[i-4], + sizeof(Real) * this_array_size); + } + } +} + +template +void SplitRadixComplexFft::ComputeTables() { + MatrixIndexT imax, lg2, i, j; + MatrixIndexT m, m2, m4, m8, nel, n; + Real *cn, *spcn, *smcn, *c3n, *spc3n, *smc3n; + Real ang, c, s; + + lg2 = logn_ >> 1; + if (logn_ & 1) lg2++; + brseed_ = new MatrixIndexT[1 << lg2]; + brseed_[0] = 0; + brseed_[1] = 1; + for (j = 2; j <= lg2; j++) { + imax = 1 << (j - 1); + for (i = 0; i < imax; i++) { + brseed_[i] <<= 1; + brseed_[i + imax] = brseed_[i] + 1; + } + } + + if (logn_ < 4) { + tab_ = NULL; + } else { + tab_ = new Real* [logn_-3]; + for (i = logn_; i>=4 ; i--) { + /* Compute a few constants */ + m = 1 << i; m2 = m / 2; m4 = m2 / 2; m8 = m4 /2; + + /* Allocate memory for tables */ + nel = m4 - 2; + + tab_[i-4] = new Real[6*nel]; + + /* Initialize pointers */ + cn = tab_[i-4]; spcn = cn + nel; smcn = spcn + nel; + c3n = smcn + nel; spc3n = c3n + nel; smc3n = spc3n + nel; + + /* Compute tables */ + for (n = 1; n < m4; n++) { + if (n == m8) continue; + ang = n * M_2PI / m; + c = std::cos(ang); s = std::sin(ang); + *cn++ = c; *spcn++ = - (s + c); *smcn++ = s - c; + ang = 3 * n * M_2PI / m; + c = std::cos(ang); s = std::sin(ang); + *c3n++ = c; *spc3n++ = - (s + c); *smc3n++ = s - c; + } + } + } +} + +template +SplitRadixComplexFft::~SplitRadixComplexFft() { + delete [] brseed_; + if (tab_ != NULL) { + for (MatrixIndexT i = 0; i < logn_-3; i++) + delete [] tab_[i]; + delete [] tab_; + } +} + +template +void 
SplitRadixComplexFft::Compute(Real *xr, Real *xi, bool forward) const { + if (!forward) { // reverse real and imaginary parts for complex FFT. + Real *tmp = xr; + xr = xi; + xi = tmp; + } + ComputeRecursive(xr, xi, logn_); + if (logn_ > 1) { + BitReversePermute(xr, logn_); + BitReversePermute(xi, logn_); + } +} + +template +void SplitRadixComplexFft::Compute(Real *x, bool forward, + std::vector *temp_buffer) const { + KALDI_ASSERT(temp_buffer != NULL); + if (temp_buffer->size() != N_) + temp_buffer->resize(N_); + Real *temp_ptr = &((*temp_buffer)[0]); + for (MatrixIndexT i = 0; i < N_; i++) { + x[i] = x[i * 2]; // put the real part in the first half of x. + temp_ptr[i] = x[i * 2 + 1]; // put the imaginary part in temp_buffer. + } + // copy the imaginary part back to the second half of x. + memcpy(static_cast(x + N_), + static_cast(temp_ptr), + sizeof(Real) * N_); + + Compute(x, x + N_, forward); + // Now change the format back to interleaved. + memcpy(static_cast(temp_ptr), + static_cast(x + N_), + sizeof(Real) * N_); + for (MatrixIndexT i = N_-1; i > 0; i--) { // don't include 0, + // in case MatrixIndexT is unsigned, the loop would not terminate. + // Treat it as a special case. + x[i*2] = x[i]; + x[i*2 + 1] = temp_ptr[i]; + } + x[1] = temp_ptr[0]; // special case of i = 0. 
+} + +template +void SplitRadixComplexFft::Compute(Real *x, bool forward) { + this->Compute(x, forward, &temp_buffer_); +} + +template +void SplitRadixComplexFft::BitReversePermute(Real *x, MatrixIndexT logn) const { + MatrixIndexT i, j, lg2, n; + MatrixIndexT off, fj, gno, *brp; + Real tmp, *xp, *xq; + + lg2 = logn >> 1; + n = 1 << lg2; + if (logn & 1) lg2++; + + /* Unshuffling loop */ + for (off = 1; off < n; off++) { + fj = n * brseed_[off]; i = off; j = fj; + tmp = x[i]; x[i] = x[j]; x[j] = tmp; + xp = &x[i]; + brp = &(brseed_[1]); + for (gno = 1; gno < brseed_[off]; gno++) { + xp += n; + j = fj + *brp++; + xq = x + j; + tmp = *xp; *xp = *xq; *xq = tmp; + } + } +} + + +template +void SplitRadixComplexFft::ComputeRecursive(Real *xr, Real *xi, MatrixIndexT logn) const { + + MatrixIndexT m, m2, m4, m8, nel, n; + Real *xr1, *xr2, *xi1, *xi2; + Real *cn = nullptr, *spcn = nullptr, *smcn = nullptr, *c3n = nullptr, + *spc3n = nullptr, *smc3n = nullptr; + Real tmp1, tmp2; + Real sqhalf = M_SQRT1_2; + + /* Check range of logn */ + if (logn < 0) + KALDI_ERR << "Error: logn is out of bounds in SRFFT"; + + /* Compute trivial cases */ + if (logn < 3) { + if (logn == 2) { /* length m = 4 */ + xr2 = xr + 2; + xi2 = xi + 2; + tmp1 = *xr + *xr2; + *xr2 = *xr - *xr2; + *xr = tmp1; + tmp1 = *xi + *xi2; + *xi2 = *xi - *xi2; + *xi = tmp1; + xr1 = xr + 1; + xi1 = xi + 1; + xr2++; + xi2++; + tmp1 = *xr1 + *xr2; + *xr2 = *xr1 - *xr2; + *xr1 = tmp1; + tmp1 = *xi1 + *xi2; + *xi2 = *xi1 - *xi2; + *xi1 = tmp1; + xr2 = xr + 1; + xi2 = xi + 1; + tmp1 = *xr + *xr2; + *xr2 = *xr - *xr2; + *xr = tmp1; + tmp1 = *xi + *xi2; + *xi2 = *xi - *xi2; + *xi = tmp1; + xr1 = xr + 2; + xi1 = xi + 2; + xr2 = xr + 3; + xi2 = xi + 3; + tmp1 = *xr1 + *xi2; + tmp2 = *xi1 + *xr2; + *xi1 = *xi1 - *xr2; + *xr2 = *xr1 - *xi2; + *xr1 = tmp1; + *xi2 = tmp2; + return; + } + else if (logn == 1) { /* length m = 2 */ + xr2 = xr + 1; + xi2 = xi + 1; + tmp1 = *xr + *xr2; + *xr2 = *xr - *xr2; + *xr = tmp1; + tmp1 = *xi + 
*xi2; + *xi2 = *xi - *xi2; + *xi = tmp1; + return; + } + else if (logn == 0) return; /* length m = 1 */ + } + + /* Compute a few constants */ + m = 1 << logn; m2 = m / 2; m4 = m2 / 2; m8 = m4 /2; + + + /* Step 1 */ + xr1 = xr; xr2 = xr1 + m2; + xi1 = xi; xi2 = xi1 + m2; + for (n = 0; n < m2; n++) { + tmp1 = *xr1 + *xr2; + *xr2 = *xr1 - *xr2; + xr2++; + *xr1++ = tmp1; + tmp2 = *xi1 + *xi2; + *xi2 = *xi1 - *xi2; + xi2++; + *xi1++ = tmp2; + } + + /* Step 2 */ + xr1 = xr + m2; xr2 = xr1 + m4; + xi1 = xi + m2; xi2 = xi1 + m4; + for (n = 0; n < m4; n++) { + tmp1 = *xr1 + *xi2; + tmp2 = *xi1 + *xr2; + *xi1 = *xi1 - *xr2; + xi1++; + *xr2++ = *xr1 - *xi2; + *xr1++ = tmp1; + *xi2++ = tmp2; + // xr1++; xr2++; xi1++; xi2++; + } + + /* Steps 3 & 4 */ + xr1 = xr + m2; xr2 = xr1 + m4; + xi1 = xi + m2; xi2 = xi1 + m4; + if (logn >= 4) { + nel = m4 - 2; + cn = tab_[logn-4]; spcn = cn + nel; smcn = spcn + nel; + c3n = smcn + nel; spc3n = c3n + nel; smc3n = spc3n + nel; + } + xr1++; xr2++; xi1++; xi2++; + // xr1++; xi1++; + for (n = 1; n < m4; n++) { + if (n == m8) { + tmp1 = sqhalf * (*xr1 + *xi1); + *xi1 = sqhalf * (*xi1 - *xr1); + *xr1 = tmp1; + tmp2 = sqhalf * (*xi2 - *xr2); + *xi2 = -sqhalf * (*xr2 + *xi2); + *xr2 = tmp2; + } else { + tmp2 = *cn++ * (*xr1 + *xi1); + tmp1 = *spcn++ * *xr1 + tmp2; + *xr1 = *smcn++ * *xi1 + tmp2; + *xi1 = tmp1; + tmp2 = *c3n++ * (*xr2 + *xi2); + tmp1 = *spc3n++ * *xr2 + tmp2; + *xr2 = *smc3n++ * *xi2 + tmp2; + *xi2 = tmp1; + } + xr1++; xr2++; xi1++; xi2++; + } + + /* Call ssrec again with half DFT length */ + ComputeRecursive(xr, xi, logn-1); + + /* Call ssrec again twice with one quarter DFT length. + Constants have to be recomputed, because they are static! 
*/ + // m = 1 << logn; m2 = m / 2; + ComputeRecursive(xr + m2, xi + m2, logn - 2); + // m = 1 << logn; + m4 = 3 * (m / 4); + ComputeRecursive(xr + m4, xi + m4, logn - 2); +} + + +template +void SplitRadixRealFft::Compute(Real *data, bool forward) { + Compute(data, forward, &this->temp_buffer_); +} + + +// This code is mostly the same as the RealFft function. It would be +// possible to replace it with more efficient code from Rico's book. +template +void SplitRadixRealFft::Compute(Real *data, bool forward, + std::vector *temp_buffer) const { + MatrixIndexT N = N_, N2 = N/2; + KALDI_ASSERT(N%2 == 0); + if (forward) // call to base class + SplitRadixComplexFft::Compute(data, true, temp_buffer); + + Real rootN_re, rootN_im; // exp(-2pi/N), forward; exp(2pi/N), backward + int forward_sign = forward ? -1 : 1; + ComplexImExp(static_cast(M_2PI/N *forward_sign), &rootN_re, &rootN_im); + Real kN_re = -forward_sign, kN_im = 0.0; // exp(-2pik/N), forward; exp(-2pik/N), backward + // kN starts out as 1.0 for forward algorithm but -1.0 for backward. + for (MatrixIndexT k = 1; 2*k <= N2; k++) { + ComplexMul(rootN_re, rootN_im, &kN_re, &kN_im); + + Real Ck_re, Ck_im, Dk_re, Dk_im; + // C_k = 1/2 (B_k + B_{N/2 - k}^*) : + Ck_re = 0.5 * (data[2*k] + data[N - 2*k]); + Ck_im = 0.5 * (data[2*k + 1] - data[N - 2*k + 1]); + // re(D_k)= 1/2 (im(B_k) + im(B_{N/2-k})): + Dk_re = 0.5 * (data[2*k + 1] + data[N - 2*k + 1]); + // im(D_k) = -1/2 (re(B_k) - re(B_{N/2-k})) + Dk_im =-0.5 * (data[2*k] - data[N - 2*k]); + // A_k = C_k + 1^(k/N) D_k: + data[2*k] = Ck_re; // A_k <-- C_k + data[2*k+1] = Ck_im; + // now A_k += D_k 1^(k/N) + ComplexAddProduct(Dk_re, Dk_im, kN_re, kN_im, &(data[2*k]), &(data[2*k+1])); + + MatrixIndexT kdash = N2 - k; + if (kdash != k) { + // Next we handle the index k' = N/2 - k. This is necessary + // to do now, to avoid invalidating data that we will later need. 
+ // The quantities C_{k'} and D_{k'} are just the conjugates of C_k + // and D_k, so the equations are simple modifications of the above, + // replacing Ck_im and Dk_im with their negatives. + data[2*kdash] = Ck_re; // A_k' <-- C_k' + data[2*kdash+1] = -Ck_im; + // now A_k' += D_k' 1^(k'/N) + // We use 1^(k'/N) = 1^((N/2 - k) / N) = 1^(1/2) 1^(-k/N) = -1 * (1^(k/N))^* + // so it's the same as 1^(k/N) but with the real part negated. + ComplexAddProduct(Dk_re, -Dk_im, -kN_re, kN_im, &(data[2*kdash]), &(data[2*kdash+1])); + } + } + + { // Now handle k = 0. + // In simple terms: after the complex fft, data[0] becomes the sum of real + // parts input[0], input[2]... and data[1] becomes the sum of imaginary + // pats input[1], input[3]... + // "zeroth" [A_0] is just the sum of input[0]+input[1]+input[2].. + // and "n2th" [A_{N/2}] is input[0]-input[1]+input[2]... . + Real zeroth = data[0] + data[1], + n2th = data[0] - data[1]; + data[0] = zeroth; + data[1] = n2th; + if (!forward) { + data[0] /= 2; + data[1] /= 2; + } + } + if (!forward) { // call to base class + SplitRadixComplexFft::Compute(data, false, temp_buffer); + for (MatrixIndexT i = 0; i < N; i++) + data[i] *= 2.0; + // This is so we get a factor of N increase, rather than N/2 which we would + // otherwise get from [ComplexFft, forward] + [ComplexFft, backward] in dimension N/2. + // It's for consistency with our normal FFT convensions. + } +} + +template class SplitRadixComplexFft; +template class SplitRadixComplexFft; +template class SplitRadixRealFft; +template class SplitRadixRealFft; + + +} // end namespace kaldi diff --git a/speechx/speechx/kaldi/matrix/srfft.h b/speechx/speechx/kaldi/matrix/srfft.h new file mode 100644 index 00000000..98ff782a --- /dev/null +++ b/speechx/speechx/kaldi/matrix/srfft.h @@ -0,0 +1,141 @@ +// matrix/srfft.h + +// Copyright 2009-2011 Microsoft Corporation; Go Vivace Inc. 
+// 2014 Daniel Povey +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +// +// This file includes a modified version of code originally published in Malvar, +// H., "Signal processing with lapped transforms, " Artech House, Inc., 1992. The +// current copyright holder of the original code, Henrique S. Malvar, has given +// his permission for the release of this modified version under the Apache +// License v2.0. + +#ifndef KALDI_MATRIX_SRFFT_H_ +#define KALDI_MATRIX_SRFFT_H_ + +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + +/// @addtogroup matrix_funcs_misc +/// @{ + + +// This class is based on code by Henrique (Rico) Malvar, from his book +// "Signal Processing with Lapped Transforms" (1992). Copied with +// permission, optimized by Go Vivace Inc., and converted into C++ by +// Microsoft Corporation +// This is a more efficient way of doing the complex FFT than ComplexFft +// (declared in matrix-functios.h), but it only works for powers of 2. +// Note: in multi-threaded code, you would need to have one of these objects per +// thread, because multiple calls to Compute in parallel would not work. +template +class SplitRadixComplexFft { + public: + typedef MatrixIndexT Integer; + + // N is the number of complex points (must be a power of two, or this + // will crash). 
Note that the constructor does some work so it's best to + // initialize the object once and do the computation many times. + SplitRadixComplexFft(Integer N); + + // Copy constructor + SplitRadixComplexFft(const SplitRadixComplexFft &other); + + // Does the FFT computation, given pointers to the real and + // imaginary parts. If "forward", do the forward FFT; else + // do the inverse FFT (without the 1/N factor). + // xr and xi are pointers to zero-based arrays of size N, + // containing the real and imaginary parts + // respectively. + void Compute(Real *xr, Real *xi, bool forward) const; + + // This version of Compute takes a single array of size N*2, + // containing [ r0 im0 r1 im1 ... ]. Otherwise its behavior is the + // same as the version above. + void Compute(Real *x, bool forward); + + + // This version of Compute is const; it operates on an array of size N*2 + // containing [ r0 im0 r1 im1 ... ], but it uses the argument "temp_buffer" as + // temporary storage instead of a class-member variable. It will allocate it if + // needed. + void Compute(Real *x, bool forward, std::vector *temp_buffer) const; + + ~SplitRadixComplexFft(); + + protected: + // temp_buffer_ is allocated only if someone calls Compute with only one Real* + // argument and we need a temporary buffer while creating interleaved data. + std::vector temp_buffer_; + private: + void ComputeTables(); + void ComputeRecursive(Real *xr, Real *xi, Integer logn) const; + void BitReversePermute(Real *x, Integer logn) const; + + Integer N_; + Integer logn_; // log(N) + + Integer *brseed_; + // brseed is Evans' seed table, ref: (Ref: D. M. W. + // Evans, "An improved digit-reversal permutation algorithm ...", + // IEEE Trans. ASSP, Aug. 1987, pp. 1120-1125). + Real **tab_; // Tables of butterfly coefficients. + + // Disallow assignment. 
+ SplitRadixComplexFft &operator =(const SplitRadixComplexFft &other); +}; + +template +class SplitRadixRealFft: private SplitRadixComplexFft { + public: + SplitRadixRealFft(MatrixIndexT N): // will fail unless N>=4 and N is a power of 2. + SplitRadixComplexFft (N/2), N_(N) { } + + // Copy constructor + SplitRadixRealFft(const SplitRadixRealFft &other): + SplitRadixComplexFft(other), N_(other.N_) { } + + /// If forward == true, this function transforms from a sequence of N real points to its complex fourier + /// transform; otherwise it goes in the reverse direction. If you call it + /// in the forward and then reverse direction and multiply by 1.0/N, you + /// will get back the original data. + /// The interpretation of the complex-FFT data is as follows: the array + /// is a sequence of complex numbers C_n of length N/2 with (real, im) format, + /// i.e. [real0, real_{N/2}, real1, im1, real2, im2, real3, im3, ...]. + void Compute(Real *x, bool forward); + + + /// This is as the other Compute() function, but it is a const version that + /// uses a user-supplied buffer. + void Compute(Real *x, bool forward, std::vector *temp_buffer) const; + + private: + // Disallow assignment. + SplitRadixRealFft &operator =(const SplitRadixRealFft &other); + int N_; +}; + + +/// @} end of "addtogroup matrix_funcs_misc" + +} // end namespace kaldi + + +#endif + diff --git a/speechx/speechx/kaldi/matrix/tp-matrix.cc b/speechx/speechx/kaldi/matrix/tp-matrix.cc new file mode 100644 index 00000000..6e34dc64 --- /dev/null +++ b/speechx/speechx/kaldi/matrix/tp-matrix.cc @@ -0,0 +1,145 @@ +// matrix/tp-matrix.cc + +// Copyright 2009-2011 Ondrej Glembek; Lukas Burget; Microsoft Corporation +// Saarland University; Yanmin Qian; Haihua Xu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "matrix/tp-matrix.h" +#include "matrix/sp-matrix.h" +#include "matrix/kaldi-matrix.h" +#include "matrix/cblas-wrappers.h" + + +namespace kaldi { + +#ifndef HAVE_ATLAS +template +void TpMatrix::Invert() { + // these are CLAPACK types + KaldiBlasInt result; + KaldiBlasInt rows = static_cast(this->num_rows_); + + // clapack call + // NOTE: Even though "U" is for upper, lapack assumes column-wise storage + // of the data. We have a row-wise storage, therefore, we need to "invert" + clapack_Xtptri(&rows, this->data_, &result); + + if (result < 0) { + KALDI_ERR << "Call to CLAPACK stptri_ function failed"; + } else if (result > 0) { + KALDI_ERR << "Matrix is singular"; + } +} +#else +template +void TpMatrix::Invert() { + // ATLAS doesn't implement triangular matrix inversion in packed + // format, so we temporarily put in non-packed format. + Matrix tmp(*this); + int rows = static_cast(this->num_rows_); + + // ATLAS call. It's really row-major ordering and a lower triangular matrix, + // but there is some weirdness with Fortran-style indexing that we need to + // take account of, so everything gets swapped. + int result = clapack_Xtrtri( rows, tmp.Data(), tmp.Stride()); + // Let's hope ATLAS has the same return value conventions as clapack. + // I couldn't find any documentation online. 
+ if (result < 0) { + KALDI_ERR << "Call to ATLAS strtri function failed"; + } else if (result > 0) { + KALDI_ERR << "Matrix is singular"; + } + (*this).CopyFromMat(tmp); +} +#endif + +template +Real TpMatrix::Determinant() { + double det = 1.0; + for (MatrixIndexT i = 0; iNumRows(); i++) { + det *= (*this)(i, i); + } + return static_cast(det); +} + + +template +void TpMatrix::Swap(TpMatrix *other) { + std::swap(this->data_, other->data_); + std::swap(this->num_rows_, other->num_rows_); +} + + +template +void TpMatrix::Cholesky(const SpMatrix &orig) { + KALDI_ASSERT(orig.NumRows() == this->NumRows()); + MatrixIndexT n = this->NumRows(); + this->SetZero(); + Real *data = this->data_, *jdata = data; // start of j'th row of matrix. + const Real *orig_jdata = orig.Data(); // start of j'th row of matrix. + for (MatrixIndexT j = 0; j < n; j++, jdata += j, orig_jdata += j) { + Real *kdata = data; // start of k'th row of matrix. + Real d(0.0); + for (MatrixIndexT k = 0; k < j; k++, kdata += k) { + Real s = cblas_Xdot(k, kdata, 1, jdata, 1); + // (*this)(j, k) = s = (orig(j, k) - s)/(*this)(k, k); + jdata[k] = s = (orig_jdata[k] - s)/kdata[k]; + d = d + s*s; + } + // d = orig(j, j) - d; + d = orig_jdata[j] - d; + + if (d >= 0.0) { + // (*this)(j, j) = std::sqrt(d); + jdata[j] = std::sqrt(d); + } else { + KALDI_ERR << "Cholesky decomposition failed. 
Maybe matrix " + "is not positive definite."; + } + } +} + +template +void TpMatrix::CopyFromMat(const MatrixBase &M, + MatrixTransposeType Trans) { + if (Trans == kNoTrans) { + KALDI_ASSERT(this->NumRows() == M.NumRows() && M.NumRows() == M.NumCols()); + MatrixIndexT D = this->NumRows(); + const Real *in_i = M.Data(); + MatrixIndexT stride = M.Stride(); + Real *out_i = this->data_; + for (MatrixIndexT i = 0; i < D; i++, in_i += stride, out_i += i) + for (MatrixIndexT j = 0; j <= i; j++) + out_i[j] = in_i[j]; + } else { + KALDI_ASSERT(this->NumRows() == M.NumRows() && M.NumRows() == M.NumCols()); + MatrixIndexT D = this->NumRows(); + const Real *in_i = M.Data(); + MatrixIndexT stride = M.Stride(); + Real *out_i = this->data_; + for (MatrixIndexT i = 0; i < D; i++, in_i++, out_i += i) { + for (MatrixIndexT j = 0; j <= i; j++) + out_i[j] = in_i[stride*j]; + } + } +} + + +template class TpMatrix; +template class TpMatrix; + +} // namespace kaldi diff --git a/speechx/speechx/kaldi/matrix/tp-matrix.h b/speechx/speechx/kaldi/matrix/tp-matrix.h new file mode 100644 index 00000000..e3b08701 --- /dev/null +++ b/speechx/speechx/kaldi/matrix/tp-matrix.h @@ -0,0 +1,134 @@ +// matrix/tp-matrix.h + +// Copyright 2009-2011 Ondrej Glembek; Lukas Burget; Microsoft Corporation; +// Saarland University; Yanmin Qian; Haihua Xu +// 2013 Johns Hopkins Universith (author: Daniel Povey) + + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. 
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_MATRIX_TP_MATRIX_H_ +#define KALDI_MATRIX_TP_MATRIX_H_ + + +#include "matrix/packed-matrix.h" + +namespace kaldi { +/// \addtogroup matrix_group +/// @{ + +template class TpMatrix; + +/// @brief Packed symetric matrix class +template +class TpMatrix : public PackedMatrix { + friend class CuTpMatrix; + friend class CuTpMatrix; + public: + TpMatrix() : PackedMatrix() {} + explicit TpMatrix(MatrixIndexT r, MatrixResizeType resize_type = kSetZero) + : PackedMatrix(r, resize_type) {} + TpMatrix(const TpMatrix& orig) : PackedMatrix(orig) {} + + /// Copy constructor from CUDA TpMatrix + /// This is defined in ../cudamatrix/cu-tp-matrix.cc + explicit TpMatrix(const CuTpMatrix &cu); + + + template explicit TpMatrix(const TpMatrix& orig) + : PackedMatrix(orig) {} + + Real operator() (MatrixIndexT r, MatrixIndexT c) const { + if (static_cast(c) > + static_cast(r)) { + KALDI_ASSERT(static_cast(c) < + static_cast(this->num_rows_)); + return 0; + } + KALDI_ASSERT(static_cast(r) < + static_cast(this->num_rows_)); + // c<=r now so don't have to check c. + return *(this->data_ + (r*(r+1)) / 2 + c); + // Duplicating code from PackedMatrix.h + } + + Real &operator() (MatrixIndexT r, MatrixIndexT c) { + KALDI_ASSERT(static_cast(r) < + static_cast(this->num_rows_)); + KALDI_ASSERT(static_cast(c) <= + static_cast(r) && + "you cannot access the upper triangle of TpMatrix using " + "a non-const matrix object."); + return *(this->data_ + (r*(r+1)) / 2 + c); + // Duplicating code from PackedMatrix.h + } + // Note: Cholesky may throw KaldiFatalError. + void Cholesky(const SpMatrix& orig); + + void Invert(); + + // Inverts in double precision. 
+ void InvertDouble() { + TpMatrix dmat(*this); + dmat.Invert(); + (*this).CopyFromTp(dmat); + } + + /// Shallow swap + void Swap(TpMatrix *other); + + /// Returns the determinant of the matrix (product of diagonals) + Real Determinant(); + + /// CopyFromMat copies the lower triangle of M into *this + /// (or the upper triangle, if Trans == kTrans). + void CopyFromMat(const MatrixBase &M, + MatrixTransposeType Trans = kNoTrans); + + /// This is implemented in ../cudamatrix/cu-tp-matrix.cc + void CopyFromMat(const CuTpMatrix &other); + + /// CopyFromTp copies another triangular matrix into this one. + void CopyFromTp(const TpMatrix &other) { + PackedMatrix::CopyFromPacked(other); + } + + template void CopyFromTp(const TpMatrix &other) { + PackedMatrix::CopyFromPacked(other); + } + + /// AddTp does *this += alpha * M. + void AddTp(const Real alpha, const TpMatrix &M) { + this->AddPacked(alpha, M); + } + + TpMatrix& operator=(const TpMatrix &other) { + PackedMatrix::operator=(other); + return *this; + } + + using PackedMatrix::Scale; + + void Resize(MatrixIndexT nRows, MatrixResizeType resize_type = kSetZero) { + PackedMatrix::Resize(nRows, resize_type); + } +}; + +/// @} end of "addtogroup matrix_group". 
+ +} // namespace kaldi + + +#endif diff --git a/speechx/speechx/kaldi/util/CMakeLists.txt b/speechx/speechx/kaldi/util/CMakeLists.txt new file mode 100644 index 00000000..1ab26df3 --- /dev/null +++ b/speechx/speechx/kaldi/util/CMakeLists.txt @@ -0,0 +1,12 @@ +add_library(kaldi-util + kaldi-holder.cc + kaldi-io.cc + kaldi-semaphore.cc + kaldi-table.cc + kaldi-thread.cc + parse-options.cc + simple-io-funcs.cc + simple-options.cc + text-utils.cc +) +target_link_libraries(kaldi-util PUBLIC kaldi-base kaldi-matrix) \ No newline at end of file diff --git a/speechx/speechx/kaldi/util/basic-filebuf.h b/speechx/speechx/kaldi/util/basic-filebuf.h new file mode 100644 index 00000000..51cf12f4 --- /dev/null +++ b/speechx/speechx/kaldi/util/basic-filebuf.h @@ -0,0 +1,994 @@ +/////////////////////////////////////////////////////////////////////////////// +// This is a modified version of the std::basic_filebuf from libc++ +// (http://libcxx.llvm.org/). +// It allows one to create basic_filebuf from an existing FILE* handle or file +// descriptor. +// +// This file is dual licensed under the MIT and the University of Illinois Open +// Source License licenses. See LICENSE.TXT for details (included at the +// bottom). 
+/////////////////////////////////////////////////////////////////////////////// +#ifndef KALDI_UTIL_BASIC_FILEBUF_H_ +#define KALDI_UTIL_BASIC_FILEBUF_H_ + +/////////////////////////////////////////////////////////////////////////////// +#include +#include +#include +#include +#include + +/////////////////////////////////////////////////////////////////////////////// +namespace kaldi { +/////////////////////////////////////////////////////////////////////////////// +template > +class basic_filebuf : public std::basic_streambuf { + public: + typedef CharT char_type; + typedef Traits traits_type; + typedef typename traits_type::int_type int_type; + typedef typename traits_type::pos_type pos_type; + typedef typename traits_type::off_type off_type; + typedef typename traits_type::state_type state_type; + + basic_filebuf(); + basic_filebuf(basic_filebuf&& rhs); + virtual ~basic_filebuf(); + + basic_filebuf& operator=(basic_filebuf&& rhs); + void swap(basic_filebuf& rhs); + + bool is_open() const; + basic_filebuf* open(const char* s, std::ios_base::openmode mode); + basic_filebuf* open(const std::string& s, std::ios_base::openmode mode); + basic_filebuf* open(int fd, std::ios_base::openmode mode); + basic_filebuf* open(FILE* f, std::ios_base::openmode mode); + basic_filebuf* close(); + + FILE* file() { return this->_M_file; } + int fd() { return fileno(this->_M_file); } + + protected: + int_type underflow() override; + int_type pbackfail(int_type c = traits_type::eof()) override; + int_type overflow(int_type c = traits_type::eof()) override; + std::basic_streambuf* + setbuf(char_type* s, std::streamsize n) override; + pos_type seekoff(off_type off, std::ios_base::seekdir way, + std::ios_base::openmode wch = + std::ios_base::in | std::ios_base::out) override; + pos_type seekpos(pos_type sp, + std::ios_base::openmode wch = + std::ios_base::in | std::ios_base::out) override; + int sync() override; + void imbue(const std::locale& loc) override; + + protected: + char* 
_M_extbuf; + const char* _M_extbufnext; + const char* _M_extbufend; + char _M_extbuf_min[8]; + size_t _M_ebs; + char_type* _M_intbuf; + size_t _M_ibs; + FILE* _M_file; + const std::codecvt* _M_cv; + state_type _M_st; + state_type _M_st_last; + std::ios_base::openmode _M_om; + std::ios_base::openmode _M_cm; + bool _M_owns_eb; + bool _M_owns_ib; + bool _M_always_noconv; + + const char* _M_get_mode(std::ios_base::openmode mode); + bool _M_read_mode(); + void _M_write_mode(); +}; + +/////////////////////////////////////////////////////////////////////////////// +template +basic_filebuf::basic_filebuf() + : _M_extbuf(nullptr), + _M_extbufnext(nullptr), + _M_extbufend(nullptr), + _M_ebs(0), + _M_intbuf(nullptr), + _M_ibs(0), + _M_file(nullptr), + _M_cv(nullptr), + _M_st(), + _M_st_last(), + _M_om(std::ios_base::openmode(0)), + _M_cm(std::ios_base::openmode(0)), + _M_owns_eb(false), + _M_owns_ib(false), + _M_always_noconv(false) { + if (std::has_facet > + (this->getloc())) { + _M_cv = &std::use_facet > + (this->getloc()); + _M_always_noconv = _M_cv->always_noconv(); + } + setbuf(0, 4096); +} + +/////////////////////////////////////////////////////////////////////////////// +template +basic_filebuf::basic_filebuf(basic_filebuf&& rhs) + : std::basic_streambuf(rhs) { + if (rhs._M_extbuf == rhs._M_extbuf_min) { + _M_extbuf = _M_extbuf_min; + _M_extbufnext = _M_extbuf + (rhs._M_extbufnext - rhs._M_extbuf); + _M_extbufend = _M_extbuf + (rhs._M_extbufend - rhs._M_extbuf); + } else { + _M_extbuf = rhs._M_extbuf; + _M_extbufnext = rhs._M_extbufnext; + _M_extbufend = rhs._M_extbufend; + } + _M_ebs = rhs._M_ebs; + _M_intbuf = rhs._M_intbuf; + _M_ibs = rhs._M_ibs; + _M_file = rhs._M_file; + _M_cv = rhs._M_cv; + _M_st = rhs._M_st; + _M_st_last = rhs._M_st_last; + _M_om = rhs._M_om; + _M_cm = rhs._M_cm; + _M_owns_eb = rhs._M_owns_eb; + _M_owns_ib = rhs._M_owns_ib; + _M_always_noconv = rhs._M_always_noconv; + if (rhs.pbase()) { + if (rhs.pbase() == rhs._M_intbuf) + this->setp(_M_intbuf, 
_M_intbuf + (rhs. epptr() - rhs.pbase())); + else + this->setp(reinterpret_cast(_M_extbuf), + reinterpret_cast(_M_extbuf) + + (rhs. epptr() - rhs.pbase())); + this->pbump(rhs. pptr() - rhs.pbase()); + } else if (rhs.eback()) { + if (rhs.eback() == rhs._M_intbuf) + this->setg(_M_intbuf, _M_intbuf + (rhs.gptr() - rhs.eback()), + _M_intbuf + (rhs.egptr() - rhs.eback())); + else + this->setg(reinterpret_cast(_M_extbuf), + reinterpret_cast(_M_extbuf) + + (rhs.gptr() - rhs.eback()), + reinterpret_cast(_M_extbuf) + + (rhs.egptr() - rhs.eback())); + } + rhs._M_extbuf = nullptr; + rhs._M_extbufnext = nullptr; + rhs._M_extbufend = nullptr; + rhs._M_ebs = 0; + rhs._M_intbuf = nullptr; + rhs._M_ibs = 0; + rhs._M_file = nullptr; + rhs._M_st = state_type(); + rhs._M_st_last = state_type(); + rhs._M_om = std::ios_base::openmode(0); + rhs._M_cm = std::ios_base::openmode(0); + rhs._M_owns_eb = false; + rhs._M_owns_ib = false; + rhs.setg(0, 0, 0); + rhs.setp(0, 0); +} + +/////////////////////////////////////////////////////////////////////////////// +template +inline +basic_filebuf& +basic_filebuf::operator=(basic_filebuf&& rhs) { + close(); + swap(rhs); + return *this; +} + +/////////////////////////////////////////////////////////////////////////////// +template +basic_filebuf::~basic_filebuf() { + // try + // { + // close(); + // } + // catch (...) 
+ // { + // } + if (_M_owns_eb) + delete [] _M_extbuf; + if (_M_owns_ib) + delete [] _M_intbuf; +} + +/////////////////////////////////////////////////////////////////////////////// +template +void +basic_filebuf::swap(basic_filebuf& rhs) { + std::basic_streambuf::swap(rhs); + if (_M_extbuf != _M_extbuf_min && rhs._M_extbuf != rhs._M_extbuf_min) { + std::swap(_M_extbuf, rhs._M_extbuf); + std::swap(_M_extbufnext, rhs._M_extbufnext); + std::swap(_M_extbufend, rhs._M_extbufend); + } else { + ptrdiff_t ln = _M_extbufnext - _M_extbuf; + ptrdiff_t le = _M_extbufend - _M_extbuf; + ptrdiff_t rn = rhs._M_extbufnext - rhs._M_extbuf; + ptrdiff_t re = rhs._M_extbufend - rhs._M_extbuf; + if (_M_extbuf == _M_extbuf_min && rhs._M_extbuf != rhs._M_extbuf_min) { + _M_extbuf = rhs._M_extbuf; + rhs._M_extbuf = rhs._M_extbuf_min; + } else if (_M_extbuf != _M_extbuf_min && + rhs._M_extbuf == rhs._M_extbuf_min) { + rhs._M_extbuf = _M_extbuf; + _M_extbuf = _M_extbuf_min; + } + _M_extbufnext = _M_extbuf + rn; + _M_extbufend = _M_extbuf + re; + rhs._M_extbufnext = rhs._M_extbuf + ln; + rhs._M_extbufend = rhs._M_extbuf + le; + } + std::swap(_M_ebs, rhs._M_ebs); + std::swap(_M_intbuf, rhs._M_intbuf); + std::swap(_M_ibs, rhs._M_ibs); + std::swap(_M_file, rhs._M_file); + std::swap(_M_cv, rhs._M_cv); + std::swap(_M_st, rhs._M_st); + std::swap(_M_st_last, rhs._M_st_last); + std::swap(_M_om, rhs._M_om); + std::swap(_M_cm, rhs._M_cm); + std::swap(_M_owns_eb, rhs._M_owns_eb); + std::swap(_M_owns_ib, rhs._M_owns_ib); + std::swap(_M_always_noconv, rhs._M_always_noconv); + if (this->eback() == reinterpret_cast(rhs._M_extbuf_min)) { + ptrdiff_t n = this->gptr() - this->eback(); + ptrdiff_t e = this->egptr() - this->eback(); + this->setg(reinterpret_cast(_M_extbuf_min), + reinterpret_cast(_M_extbuf_min) + n, + reinterpret_cast(_M_extbuf_min) + e); + } else if (this->pbase() == + reinterpret_cast(rhs._M_extbuf_min)) { + ptrdiff_t n = this->pptr() - this->pbase(); + ptrdiff_t e = this->epptr() - 
this->pbase(); + this->setp(reinterpret_cast(_M_extbuf_min), + reinterpret_cast(_M_extbuf_min) + e); + this->pbump(n); + } + if (rhs.eback() == reinterpret_cast(_M_extbuf_min)) { + ptrdiff_t n = rhs.gptr() - rhs.eback(); + ptrdiff_t e = rhs.egptr() - rhs.eback(); + rhs.setg(reinterpret_cast(rhs._M_extbuf_min), + reinterpret_cast(rhs._M_extbuf_min) + n, + reinterpret_cast(rhs._M_extbuf_min) + e); + } else if (rhs.pbase() == reinterpret_cast(_M_extbuf_min)) { + ptrdiff_t n = rhs.pptr() - rhs.pbase(); + ptrdiff_t e = rhs.epptr() - rhs.pbase(); + rhs.setp(reinterpret_cast(rhs._M_extbuf_min), + reinterpret_cast(rhs._M_extbuf_min) + e); + rhs.pbump(n); + } +} + +/////////////////////////////////////////////////////////////////////////////// +template +inline +void +swap(basic_filebuf& x, basic_filebuf& y) { + x.swap(y); +} + +/////////////////////////////////////////////////////////////////////////////// +template +inline +bool +basic_filebuf::is_open() const { + return _M_file != nullptr; +} + +/////////////////////////////////////////////////////////////////////////////// +template +const char* basic_filebuf:: +_M_get_mode(std::ios_base::openmode mode) { + switch ((mode & ~std::ios_base::ate) | 0) { + case std::ios_base::out: + case std::ios_base::out | std::ios_base::trunc: + return "w"; + case std::ios_base::out | std::ios_base::app: + case std::ios_base::app: + return "a"; + break; + case std::ios_base::in: + return "r"; + case std::ios_base::in | std::ios_base::out: + return "r+"; + case std::ios_base::in | std::ios_base::out | std::ios_base::trunc: + return "w+"; + case std::ios_base::in | std::ios_base::out | std::ios_base::app: + case std::ios_base::in | std::ios_base::app: + return "a+"; + case std::ios_base::out | std::ios_base::binary: + case std::ios_base::out | std::ios_base::trunc | std::ios_base::binary: + return "wb"; + case std::ios_base::out | std::ios_base::app | std::ios_base::binary: + case std::ios_base::app | std::ios_base::binary: + return "ab"; 
+ case std::ios_base::in | std::ios_base::binary: + return "rb"; + case std::ios_base::in | std::ios_base::out | std::ios_base::binary: + return "r+b"; + case std::ios_base::in | std::ios_base::out | std::ios_base::trunc | + std::ios_base::binary: + return "w+b"; + case std::ios_base::in | std::ios_base::out | std::ios_base::app | + std::ios_base::binary: + case std::ios_base::in | std::ios_base::app | std::ios_base::binary: + return "a+b"; + default: + return nullptr; + } +} + +/////////////////////////////////////////////////////////////////////////////// +template +basic_filebuf* +basic_filebuf:: +open(const char* s, std::ios_base::openmode mode) { + basic_filebuf* rt = nullptr; + if (_M_file == nullptr) { + const char* md= _M_get_mode(mode); + if (md) { + _M_file = fopen(s, md); + if (_M_file) { + rt = this; + _M_om = mode; + if (mode & std::ios_base::ate) { + if (fseek(_M_file, 0, SEEK_END)) { + fclose(_M_file); + _M_file = nullptr; + rt = nullptr; + } + } + } + } + } + return rt; +} + +/////////////////////////////////////////////////////////////////////////////// +template +inline +basic_filebuf* +basic_filebuf::open(const std::string& s, + std::ios_base::openmode mode) { + return open(s.c_str(), mode); +} + +/////////////////////////////////////////////////////////////////////////////// +template +basic_filebuf* +basic_filebuf::open(int fd, std::ios_base::openmode mode) { + const char* md= this->_M_get_mode(mode); + if (md) { + this->_M_file= fdopen(fd, md); + this->_M_om = mode; + return this; + } else { + return nullptr; + } +} + +/////////////////////////////////////////////////////////////////////////////// +template +basic_filebuf* +basic_filebuf::open(FILE* f, std::ios_base::openmode mode) { + this->_M_file = f; + this->_M_om = mode; + return this; +} + +/////////////////////////////////////////////////////////////////////////////// +template +basic_filebuf* +basic_filebuf::close() { + basic_filebuf* rt = nullptr; + if (_M_file) { + rt = this; + 
std::unique_ptr h(_M_file, fclose); + if (sync()) + rt = nullptr; + if (fclose(h.release()) == 0) + _M_file = nullptr; + else + rt = nullptr; + } + return rt; +} + +/////////////////////////////////////////////////////////////////////////////// +template +typename basic_filebuf::int_type +basic_filebuf::underflow() { + if (_M_file == nullptr) + return traits_type::eof(); + bool initial = _M_read_mode(); + char_type buf; + if (this->gptr() == nullptr) + this->setg(&buf, &buf+1, &buf+1); + const size_t unget_sz = initial ? 0 : std:: + min((this->egptr() - this->eback()) / 2, 4); + int_type c = traits_type::eof(); + if (this->gptr() == this->egptr()) { + memmove(this->eback(), this->egptr() - unget_sz, + unget_sz * sizeof(char_type)); + if (_M_always_noconv) { + size_t nmemb = static_cast + (this->egptr() - this->eback() - unget_sz); + nmemb = fread(this->eback() + unget_sz, 1, nmemb, _M_file); + if (nmemb != 0) { + this->setg(this->eback(), + this->eback() + unget_sz, + this->eback() + unget_sz + nmemb); + c = traits_type::to_int_type(*this->gptr()); + } + } else { + memmove(_M_extbuf, _M_extbufnext, _M_extbufend - _M_extbufnext); + _M_extbufnext = _M_extbuf + (_M_extbufend - _M_extbufnext); + _M_extbufend = _M_extbuf + + (_M_extbuf == _M_extbuf_min ? 
sizeof(_M_extbuf_min) : _M_ebs); + size_t nmemb = std::min(static_cast(_M_ibs - unget_sz), + static_cast + (_M_extbufend - _M_extbufnext)); + std::codecvt_base::result r; + _M_st_last = _M_st; + size_t nr = fread( + reinterpret_cast(const_cast(_M_extbufnext)), + 1, nmemb, _M_file); + if (nr != 0) { + if (!_M_cv) + throw std::bad_cast(); + _M_extbufend = _M_extbufnext + nr; + char_type* inext; + r = _M_cv->in(_M_st, _M_extbuf, _M_extbufend, _M_extbufnext, + this->eback() + unget_sz, + this->eback() + _M_ibs, inext); + if (r == std::codecvt_base::noconv) { + this->setg(reinterpret_cast(_M_extbuf), + reinterpret_cast(_M_extbuf), + const_cast(_M_extbufend)); + c = traits_type::to_int_type(*this->gptr()); + } else if (inext != this->eback() + unget_sz) { + this->setg(this->eback(), this->eback() + unget_sz, inext); + c = traits_type::to_int_type(*this->gptr()); + } + } + } + } else { + c = traits_type::to_int_type(*this->gptr()); + } + if (this->eback() == &buf) + this->setg(0, 0, 0); + return c; +} + +/////////////////////////////////////////////////////////////////////////////// +template +typename basic_filebuf::int_type +basic_filebuf::pbackfail(int_type c) { + if (_M_file && this->eback() < this->gptr()) { + if (traits_type::eq_int_type(c, traits_type::eof())) { + this->gbump(-1); + return traits_type::not_eof(c); + } + if ((_M_om & std::ios_base::out) || + traits_type::eq(traits_type::to_char_type(c), this->gptr()[-1])) { + this->gbump(-1); + *this->gptr() = traits_type::to_char_type(c); + return c; + } + } + return traits_type::eof(); +} + +/////////////////////////////////////////////////////////////////////////////// +template +typename basic_filebuf::int_type +basic_filebuf::overflow(int_type c) { + if (_M_file == nullptr) + return traits_type::eof(); + _M_write_mode(); + char_type buf; + char_type* pb_save = this->pbase(); + char_type* epb_save = this->epptr(); + if (!traits_type::eq_int_type(c, traits_type::eof())) { + if (this->pptr() == nullptr) + 
this->setp(&buf, &buf+1); + *this->pptr() = traits_type::to_char_type(c); + this->pbump(1); + } + if (this->pptr() != this->pbase()) { + if (_M_always_noconv) { + size_t nmemb = static_cast(this->pptr() - this->pbase()); + if (fwrite(this->pbase(), sizeof(char_type), + nmemb, _M_file) != nmemb) + return traits_type::eof(); + } else { + char* extbe = _M_extbuf; + std::codecvt_base::result r; + do { + if (!_M_cv) + throw std::bad_cast(); + const char_type* e; + r = _M_cv->out(_M_st, this->pbase(), this->pptr(), e, + _M_extbuf, _M_extbuf + _M_ebs, extbe); + if (e == this->pbase()) + return traits_type::eof(); + if (r == std::codecvt_base::noconv) { + size_t nmemb = static_cast + (this->pptr() - this->pbase()); + if (fwrite(this->pbase(), 1, nmemb, _M_file) != nmemb) + return traits_type::eof(); + } else if (r == std::codecvt_base::ok || + r == std::codecvt_base::partial) { + size_t nmemb = static_cast(extbe - _M_extbuf); + if (fwrite(_M_extbuf, 1, nmemb, _M_file) != nmemb) + return traits_type::eof(); + if (r == std::codecvt_base::partial) { + this->setp(const_cast(e), + this->pptr()); + this->pbump(this->epptr() - this->pbase()); + } + } else { + return traits_type::eof(); + } + } while (r == std::codecvt_base::partial); + } + this->setp(pb_save, epb_save); + } + return traits_type::not_eof(c); +} + +/////////////////////////////////////////////////////////////////////////////// +template +std::basic_streambuf* +basic_filebuf::setbuf(char_type* s, std::streamsize n) { + this->setg(0, 0, 0); + this->setp(0, 0); + if (_M_owns_eb) + delete [] _M_extbuf; + if (_M_owns_ib) + delete [] _M_intbuf; + _M_ebs = n; + if (_M_ebs > sizeof(_M_extbuf_min)) { + if (_M_always_noconv && s) { + _M_extbuf = reinterpret_cast(s); + _M_owns_eb = false; + } else { + _M_extbuf = new char[_M_ebs]; + _M_owns_eb = true; + } + } else { + _M_extbuf = _M_extbuf_min; + _M_ebs = sizeof(_M_extbuf_min); + _M_owns_eb = false; + } + if (!_M_always_noconv) { + _M_ibs = std::max(n, sizeof(_M_extbuf_min)); 
+ if (s && _M_ibs >= sizeof(_M_extbuf_min)) { + _M_intbuf = s; + _M_owns_ib = false; + } else { + _M_intbuf = new char_type[_M_ibs]; + _M_owns_ib = true; + } + } else { + _M_ibs = 0; + _M_intbuf = 0; + _M_owns_ib = false; + } + return this; +} + +/////////////////////////////////////////////////////////////////////////////// +template +typename basic_filebuf::pos_type +basic_filebuf::seekoff(off_type off, std::ios_base::seekdir way, + std::ios_base::openmode) { + if (!_M_cv) + throw std::bad_cast(); + int width = _M_cv->encoding(); + if (_M_file == nullptr || (width <= 0 && off != 0) || sync()) + return pos_type(off_type(-1)); + // width > 0 || off == 0 + int whence; + switch (way) { + case std::ios_base::beg: + whence = SEEK_SET; + break; + case std::ios_base::cur: + whence = SEEK_CUR; + break; + case std::ios_base::end: + whence = SEEK_END; + break; + default: + return pos_type(off_type(-1)); + } +#if _WIN32 + if (fseek(_M_file, width > 0 ? width * off : 0, whence)) + return pos_type(off_type(-1)); + pos_type r = ftell(_M_file); +#else + if (fseeko(_M_file, width > 0 ? 
width * off : 0, whence)) + return pos_type(off_type(-1)); + pos_type r = ftello(_M_file); +#endif + r.state(_M_st); + return r; +} + +/////////////////////////////////////////////////////////////////////////////// +template +typename basic_filebuf::pos_type +basic_filebuf::seekpos(pos_type sp, std::ios_base::openmode) { + if (_M_file == nullptr || sync()) + return pos_type(off_type(-1)); +#if _WIN32 + if (fseek(_M_file, sp, SEEK_SET)) + return pos_type(off_type(-1)); +#else + if (fseeko(_M_file, sp, SEEK_SET)) + return pos_type(off_type(-1)); +#endif + _M_st = sp.state(); + return sp; +} + +/////////////////////////////////////////////////////////////////////////////// +template +int +basic_filebuf::sync() { + if (_M_file == nullptr) + return 0; + if (!_M_cv) + throw std::bad_cast(); + if (_M_cm & std::ios_base::out) { + if (this->pptr() != this->pbase()) + if (overflow() == traits_type::eof()) + return -1; + std::codecvt_base::result r; + do { + char* extbe; + r = _M_cv->unshift(_M_st, _M_extbuf, _M_extbuf + _M_ebs, extbe); + size_t nmemb = static_cast(extbe - _M_extbuf); + if (fwrite(_M_extbuf, 1, nmemb, _M_file) != nmemb) + return -1; + } while (r == std::codecvt_base::partial); + if (r == std::codecvt_base::error) + return -1; + if (fflush(_M_file)) + return -1; + } else if (_M_cm & std::ios_base::in) { + off_type c; + state_type state = _M_st_last; + bool update_st = false; + if (_M_always_noconv) { + c = this->egptr() - this->gptr(); + } else { + int width = _M_cv->encoding(); + c = _M_extbufend - _M_extbufnext; + if (width > 0) { + c += width * (this->egptr() - this->gptr()); + } else { + if (this->gptr() != this->egptr()) { + const int off = _M_cv->length(state, _M_extbuf, + _M_extbufnext, + this->gptr() - this->eback()); + c += _M_extbufnext - _M_extbuf - off; + update_st = true; + } + } + } +#if _WIN32 + if (fseek(_M_file_, -c, SEEK_CUR)) + return -1; +#else + if (fseeko(_M_file, -c, SEEK_CUR)) + return -1; +#endif + if (update_st) + _M_st = state; + 
_M_extbufnext = _M_extbufend = _M_extbuf; + this->setg(0, 0, 0); + _M_cm = std::ios_base::openmode(0); + } + return 0; +} + +/////////////////////////////////////////////////////////////////////////////// +template +void +basic_filebuf::imbue(const std::locale& loc) { + sync(); + _M_cv = &std::use_facet >(loc); + bool old_anc = _M_always_noconv; + _M_always_noconv = _M_cv->always_noconv(); + if (old_anc != _M_always_noconv) { + this->setg(0, 0, 0); + this->setp(0, 0); + // invariant, char_type is char, else we couldn't get here + // need to dump _M_intbuf + if (_M_always_noconv) { + if (_M_owns_eb) + delete [] _M_extbuf; + _M_owns_eb = _M_owns_ib; + _M_ebs = _M_ibs; + _M_extbuf = reinterpret_cast(_M_intbuf); + _M_ibs = 0; + _M_intbuf = nullptr; + _M_owns_ib = false; + } else { // need to obtain an _M_intbuf. + // If _M_extbuf is user-supplied, use it, else new _M_intbuf + if (!_M_owns_eb && _M_extbuf != _M_extbuf_min) { + _M_ibs = _M_ebs; + _M_intbuf = reinterpret_cast(_M_extbuf); + _M_owns_ib = false; + _M_extbuf = new char[_M_ebs]; + _M_owns_eb = true; + } else { + _M_ibs = _M_ebs; + _M_intbuf = new char_type[_M_ibs]; + _M_owns_ib = true; + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////// +template +bool +basic_filebuf::_M_read_mode() { + if (!(_M_cm & std::ios_base::in)) { + this->setp(0, 0); + if (_M_always_noconv) + this->setg(reinterpret_cast(_M_extbuf), + reinterpret_cast(_M_extbuf) + _M_ebs, + reinterpret_cast(_M_extbuf) + _M_ebs); + else + this->setg(_M_intbuf, _M_intbuf + _M_ibs, _M_intbuf + _M_ibs); + _M_cm = std::ios_base::in; + return true; + } + return false; +} + +/////////////////////////////////////////////////////////////////////////////// +template +void +basic_filebuf::_M_write_mode() { + if (!(_M_cm & std::ios_base::out)) { + this->setg(0, 0, 0); + if (_M_ebs > sizeof(_M_extbuf_min)) { + if (_M_always_noconv) + this->setp(reinterpret_cast(_M_extbuf), + reinterpret_cast(_M_extbuf) + + (_M_ebs - 1)); 
+ else + this->setp(_M_intbuf, _M_intbuf + (_M_ibs - 1)); + } else { + this->setp(0, 0); + } + _M_cm = std::ios_base::out; + } +} + +/////////////////////////////////////////////////////////////////////////////// +} + +/////////////////////////////////////////////////////////////////////////////// +#endif // KALDI_UTIL_BASIC_FILEBUF_H_ + +/////////////////////////////////////////////////////////////////////////////// + +/* + * ============================================================================ + * libc++ License + * ============================================================================ + * + * The libc++ library is dual licensed under both the University of Illinois + * "BSD-Like" license and the MIT license. As a user of this code you may + * choose to use it under either license. As a contributor, you agree to allow + * your code to be used under both. + * + * Full text of the relevant licenses is included below. + * + * ============================================================================ + * + * University of Illinois/NCSA + * Open Source License + * + * Copyright (c) 2009-2014 by the contributors listed in CREDITS.TXT (included below) + * + * All rights reserved. + * + * Developed by: + * + * LLVM Team + * + * University of Illinois at Urbana-Champaign + * + * http://llvm.org + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of + * this software and associated documentation files (the "Software"), to deal with + * the Software without restriction, including without limitation the rights to + * use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies + * of the Software, and to permit persons to whom the Software is furnished to do + * so, subject to the following conditions: + * + * * Redistributions of source code must retain the above copyright notice, + * this list of conditions and the following disclaimers. 
+ * + * * Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimers in the + * documentation and/or other materials provided with the distribution. + * + * * Neither the names of the LLVM Team, University of Illinois at + * Urbana-Champaign, nor the names of its contributors may be used to + * endorse or promote products derived from this Software without specific + * prior written permission. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS + * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH THE + * SOFTWARE. + * + * ============================================================================== + * + * Copyright (c) 2009-2014 by the contributors listed in CREDITS.TXT (included below) + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * ============================================================================== + * + * This file is a partial list of people who have contributed to the LLVM/libc++ + * project. If you have contributed a patch or made some other contribution to + * LLVM/libc++, please submit a patch to this file to add yourself, and it will be + * done! + * + * The list is sorted by surname and formatted to allow easy grepping and + * beautification by scripts. The fields are: name (N), email (E), web-address + * (W), PGP key ID and fingerprint (P), description (D), and snail-mail address + * (S). + * + * N: Saleem Abdulrasool + * E: compnerd@compnerd.org + * D: Minor patches and Linux fixes. + * + * N: Dimitry Andric + * E: dimitry@andric.com + * D: Visibility fixes, minor FreeBSD portability patches. + * + * N: Holger Arnold + * E: holgerar@gmail.com + * D: Minor fix. + * + * N: Ruben Van Boxem + * E: vanboxem dot ruben at gmail dot com + * D: Initial Windows patches. + * + * N: David Chisnall + * E: theraven at theravensnest dot org + * D: FreeBSD and Solaris ports, libcxxrt support, some atomics work. + * + * N: Marshall Clow + * E: mclow.lists@gmail.com + * E: marshall@idio.com + * D: C++14 support, patches and bug fixes. + * + * N: Bill Fisher + * E: william.w.fisher@gmail.com + * D: Regex bug fixes. + * + * N: Matthew Dempsky + * E: matthew@dempsky.org + * D: Minor patches and bug fixes. + * + * N: Google Inc. + * D: Copyright owner and contributor of the CityHash algorithm + * + * N: Howard Hinnant + * E: hhinnant@apple.com + * D: Architect and primary author of libc++ + * + * N: Hyeon-bin Jeong + * E: tuhertz@gmail.com + * D: Minor patches and bug fixes. 
+ * + * N: Argyrios Kyrtzidis + * E: kyrtzidis@apple.com + * D: Bug fixes. + * + * N: Bruce Mitchener, Jr. + * E: bruce.mitchener@gmail.com + * D: Emscripten-related changes. + * + * N: Michel Morin + * E: mimomorin@gmail.com + * D: Minor patches to is_convertible. + * + * N: Andrew Morrow + * E: andrew.c.morrow@gmail.com + * D: Minor patches and Linux fixes. + * + * N: Arvid Picciani + * E: aep at exys dot org + * D: Minor patches and musl port. + * + * N: Bjorn Reese + * E: breese@users.sourceforge.net + * D: Initial regex prototype + * + * N: Nico Rieck + * E: nico.rieck@gmail.com + * D: Windows fixes + * + * N: Jonathan Sauer + * D: Minor patches, mostly related to constexpr + * + * N: Craig Silverstein + * E: csilvers@google.com + * D: Implemented Cityhash as the string hash function on 64-bit machines + * + * N: Richard Smith + * D: Minor patches. + * + * N: Joerg Sonnenberger + * E: joerg@NetBSD.org + * D: NetBSD port. + * + * N: Stephan Tolksdorf + * E: st@quanttec.com + * D: Minor fix + * + * N: Michael van der Westhuizen + * E: r1mikey at gmail dot com + * + * N: Klaas de Vries + * E: klaas at klaasgaaf dot nl + * D: Minor bug fix. + * + * N: Zhang Xiongpang + * E: zhangxiongpang@gmail.com + * D: Minor patches and bug fixes. + * + * N: Xing Xue + * E: xingxue@ca.ibm.com + * D: AIX port + * + * N: Zhihao Yuan + * E: lichray@gmail.com + * D: Standard compatibility fixes. + * + * N: Jeffrey Yasskin + * E: jyasskin@gmail.com + * E: jyasskin@google.com + * D: Linux fixes. 
+ */ diff --git a/speechx/speechx/kaldi/util/common-utils.h b/speechx/speechx/kaldi/util/common-utils.h new file mode 100644 index 00000000..cfb0c255 --- /dev/null +++ b/speechx/speechx/kaldi/util/common-utils.h @@ -0,0 +1,31 @@ +// util/common-utils.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_UTIL_COMMON_UTILS_H_ +#define KALDI_UTIL_COMMON_UTILS_H_ + +#include "base/kaldi-common.h" +#include "util/parse-options.h" +#include "util/kaldi-io.h" +#include "util/simple-io-funcs.h" +#include "util/kaldi-holder.h" +#include "util/kaldi-table.h" +#include "util/table-types.h" +#include "util/text-utils.h" + +#endif // KALDI_UTIL_COMMON_UTILS_H_ diff --git a/speechx/speechx/kaldi/util/const-integer-set-inl.h b/speechx/speechx/kaldi/util/const-integer-set-inl.h new file mode 100644 index 00000000..32560535 --- /dev/null +++ b/speechx/speechx/kaldi/util/const-integer-set-inl.h @@ -0,0 +1,91 @@ +// util/const-integer-set-inl.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+
+#ifndef KALDI_UTIL_CONST_INTEGER_SET_INL_H_
+#define KALDI_UTIL_CONST_INTEGER_SET_INL_H_
+
+// Do not include this file directly.  It is included by const-integer-set.h
+
+
+namespace kaldi {
+
+template<class I>
+void ConstIntegerSet<I>::InitInternal() {
+  KALDI_ASSERT_IS_INTEGER_TYPE(I);
+  quick_set_.clear();  // just in case we previously had data.
+  if (slow_set_.size() == 0) {
+    lowest_member_ = (I) 1;
+    highest_member_ = (I) 0;
+    contiguous_ = false;
+    quick_ = false;
+  } else {
+    lowest_member_ = slow_set_.front();
+    highest_member_ = slow_set_.back();
+    size_t range = highest_member_ + 1 - lowest_member_;
+    if (range == slow_set_.size()) {
+      contiguous_ = true;
+      quick_ = false;
+    } else {
+      contiguous_ = false;
+      // If it would be more compact to store as bool
+      if (range < slow_set_.size() * 8 * sizeof(I)) {
+        // (assuming 1 bit per element)...
+        quick_set_.resize(range, false);
+        for (size_t i = 0; i < slow_set_.size(); i++)
+          quick_set_[slow_set_[i] - lowest_member_] = true;
+        quick_ = true;
+      } else {
+        quick_ = false;
+      }
+    }
+  }
+}
+
+template<class I>
+int ConstIntegerSet<I>::count(I i) const {
+  if (i < lowest_member_ || i > highest_member_) {
+    return 0;
+  } else {
+    if (contiguous_) return true;
+    if (quick_) {
+      return (quick_set_[i-lowest_member_] ? 1 : 0);
+    } else {
+      bool ans = std::binary_search(slow_set_.begin(), slow_set_.end(), i);
+      return (ans ? 1 : 0);
+    }
+  }
+}
+
+template<class I>
+void ConstIntegerSet<I>::Write(std::ostream &os, bool binary) const {
+  WriteIntegerVector(os, binary, slow_set_);
+}
+
+template<class I>
+void ConstIntegerSet<I>::Read(std::istream &is, bool binary) {
+  ReadIntegerVector(is, binary, &slow_set_);
+  InitInternal();
+}
+
+
+
+}  // end namespace kaldi
+
+#endif  // KALDI_UTIL_CONST_INTEGER_SET_INL_H_
diff --git a/speechx/speechx/kaldi/util/const-integer-set.h b/speechx/speechx/kaldi/util/const-integer-set.h
new file mode 100644
index 00000000..bb10a504
--- /dev/null
+++ b/speechx/speechx/kaldi/util/const-integer-set.h
@@ -0,0 +1,96 @@
+// util/const-integer-set.h
+
+// Copyright 2009-2011     Microsoft Corporation
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+
+#ifndef KALDI_UTIL_CONST_INTEGER_SET_H_
+#define KALDI_UTIL_CONST_INTEGER_SET_H_
+#include <set>
+#include <vector>
+#include <algorithm>
+#include <limits>
+#include <cassert>
+#include "util/stl-utils.h"
+
+ /* ConstIntegerSet is a way to efficiently test whether something is in a
+    supplied set of integers.  It can be initialized from a vector or set, but
+    never changed after that.  It either uses a sorted vector or an array of
+    bool, depending on the input.  It behaves like a const version of an STL set, with
+    only a subset of the functionality, except all the member functions are
+    upper-case.
+
+    Note that we could get rid of the member slow_set_, but we'd have to
+    do more work to implement an iterator type.  This would save memory.
+ */
+
+namespace kaldi {
+
+template<class I> class ConstIntegerSet {
+ public:
+  ConstIntegerSet(): lowest_member_(1), highest_member_(0) { }
+
+  void Init(const std::vector<I> &input) {
+    slow_set_ = input;
+    SortAndUniq(&slow_set_);
+    InitInternal();
+  }
+
+  void Init(const std::set<I> &input) {
+    CopySetToVector(input, &slow_set_);
+    InitInternal();
+  }
+
+  explicit ConstIntegerSet(const std::vector<I> &input): slow_set_(input) {
+    SortAndUniq(&slow_set_);
+    InitInternal();
+  }
+  explicit ConstIntegerSet(const std::set<I> &input) {
+    CopySetToVector(input, &slow_set_);
+    InitInternal();
+  }
+  explicit ConstIntegerSet(const ConstIntegerSet<I> &other):
+      slow_set_(other.slow_set_) {
+    InitInternal();
+  }
+
+  int count(I i) const;  // returns 1 or 0.
+
+  typedef typename std::vector<I>::const_iterator iterator;
+  iterator begin() const { return slow_set_.begin(); }
+  iterator end() const { return slow_set_.end(); }
+  size_t size() const { return slow_set_.size(); }
+  bool empty() const { return slow_set_.empty(); }
+
+  void Write(std::ostream &os, bool binary) const;
+  void Read(std::istream &is, bool binary);
+
+ private:
+  I lowest_member_;
+  I highest_member_;
+  bool contiguous_;
+  bool quick_;
+  std::vector<bool> quick_set_;
+  std::vector<I> slow_set_;
+  void InitInternal();
+};
+
+}  // end namespace kaldi
+
+#include "util/const-integer-set-inl.h"
+
+#endif  // KALDI_UTIL_CONST_INTEGER_SET_H_
diff --git a/speechx/speechx/kaldi/util/edit-distance-inl.h b/speechx/speechx/kaldi/util/edit-distance-inl.h
new file mode 100644
index 00000000..3304b27d
--- /dev/null
+++ b/speechx/speechx/kaldi/util/edit-distance-inl.h
@@ -0,0 +1,200 @@
+// util/edit-distance-inl.h
+
+// Copyright 2009-2011  Microsoft Corporation;  Haihua Xu;  Yanmin Qian
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the
"License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_UTIL_EDIT_DISTANCE_INL_H_ +#define KALDI_UTIL_EDIT_DISTANCE_INL_H_ +#include +#include +#include +#include "util/stl-utils.h" + +namespace kaldi { + +template +int32 LevenshteinEditDistance(const std::vector &a, + const std::vector &b) { + // Algorithm: + // write A and B for the sequences, with elements a_0 .. + // let |A| = M and |B| = N be the lengths, and have + // elements a_0 ... a_{M-1} and b_0 ... b_{N-1}. + // We are computing the recursion + // E(m, n) = min( E(m-1, n-1) + (1-delta(a_{m-1}, b_{n-1})), + // E(m-1, n) + 1, + // E(m, n-1) + 1). + // where E(m, n) is defined for m = 0..M and n = 0..N and out-of- + // bounds quantities are considered to be infinity (i.e. the + // recursion does not visit them). + + // We do this computation using a vector e of size N+1. + // The outer iterations range over m = 0..M. + + int M = a.size(), N = b.size(); + std::vector e(N+1); + std::vector e_tmp(N+1); + // initialize e. + for (size_t i = 0; i < e.size(); i++) + e[i] = i; + for (int32 m = 1; m <= M; m++) { + // computing E(m, .) from E(m-1, .) + // handle special case n = 0: + e_tmp[0] = e[0] + 1; + + for (int32 n = 1; n <= N; n++) { + int32 term1 = e[n-1] + (a[m-1] == b[n-1] ? 
0 : 1); + int32 term2 = e[n] + 1; + int32 term3 = e_tmp[n-1] + 1; + e_tmp[n] = std::min(term1, std::min(term2, term3)); + } + e = e_tmp; + } + return e.back(); +} +// +struct error_stats { + int32 ins_num; + int32 del_num; + int32 sub_num; + int32 total_cost; // minimum total cost to the current alignment. +}; +// Note that both hyp and ref should not contain noise word in +// the following implementation. + +template +int32 LevenshteinEditDistance(const std::vector &ref, + const std::vector &hyp, + int32 *ins, int32 *del, int32 *sub) { + // temp sequence to remember error type and stats. + std::vector e(ref.size()+1); + std::vector cur_e(ref.size()+1); + // initialize the first hypothesis aligned to the reference at each + // position:[hyp_index =0][ref_index] + for (size_t i =0; i < e.size(); i ++) { + e[i].ins_num = 0; + e[i].sub_num = 0; + e[i].del_num = i; + e[i].total_cost = i; + } + + // for other alignments + for (size_t hyp_index = 1; hyp_index <= hyp.size(); hyp_index ++) { + cur_e[0] = e[0]; + cur_e[0].ins_num++; + cur_e[0].total_cost++; + for (size_t ref_index = 1; ref_index <= ref.size(); ref_index ++) { + int32 ins_err = e[ref_index].total_cost + 1; + int32 del_err = cur_e[ref_index-1].total_cost + 1; + int32 sub_err = e[ref_index-1].total_cost; + if (hyp[hyp_index-1] != ref[ref_index-1]) + sub_err++; + + if (sub_err < ins_err && sub_err < del_err) { + cur_e[ref_index] =e[ref_index-1]; + if (hyp[hyp_index-1] != ref[ref_index-1]) + cur_e[ref_index].sub_num++; // substitution error should be increased + cur_e[ref_index].total_cost = sub_err; + } else if (del_err < ins_err) { + cur_e[ref_index] = cur_e[ref_index-1]; + cur_e[ref_index].total_cost = del_err; + cur_e[ref_index].del_num++; // deletion number is increased. + } else { + cur_e[ref_index] = e[ref_index]; + cur_e[ref_index].total_cost = ins_err; + cur_e[ref_index].ins_num++; // insertion number is increased. + } + } + e = cur_e; // alternate for the next recursion. 
+ } + size_t ref_index = e.size()-1; + *ins = e[ref_index].ins_num, *del = + e[ref_index].del_num, *sub = e[ref_index].sub_num; + return e[ref_index].total_cost; +} + +template +int32 LevenshteinAlignment(const std::vector &a, + const std::vector &b, + T eps_symbol, + std::vector > *output) { + // Check inputs: + { + KALDI_ASSERT(output != NULL); + for (size_t i = 0; i < a.size(); i++) KALDI_ASSERT(a[i] != eps_symbol); + for (size_t i = 0; i < b.size(); i++) KALDI_ASSERT(b[i] != eps_symbol); + } + output->clear(); + // This is very memory-inefficiently implemented using a vector of vectors. + size_t M = a.size(), N = b.size(); + size_t m, n; + std::vector > e(M+1); + for (m = 0; m <=M; m++) e[m].resize(N+1); + for (n = 0; n <= N; n++) + e[0][n] = n; + for (m = 1; m <= M; m++) { + e[m][0] = e[m-1][0] + 1; + for (n = 1; n <= N; n++) { + int32 sub_or_ok = e[m-1][n-1] + (a[m-1] == b[n-1] ? 0 : 1); + int32 del = e[m-1][n] + 1; // assumes a == ref, b == hyp. + int32 ins = e[m][n-1] + 1; + e[m][n] = std::min(sub_or_ok, std::min(del, ins)); + } + } + // get time-reversed output first: trace back. + m = M; + n = N; + while (m != 0 || n != 0) { + size_t last_m, last_n; + if (m == 0) { + last_m = m; + last_n = n-1; + } else if (n == 0) { + last_m = m-1; + last_n = n; + } else { + int32 sub_or_ok = e[m-1][n-1] + (a[m-1] == b[n-1] ? 0 : 1); + int32 del = e[m-1][n] + 1; // assumes a == ref, b == hyp. + int32 ins = e[m][n-1] + 1; + // choose sub_or_ok if all else equal. + if (sub_or_ok <= std::min(del, ins)) { + last_m = m-1; + last_n = n-1; + } else { + if (del <= ins) { // choose del over ins if equal. + last_m = m-1; + last_n = n; + } else { + last_m = m; + last_n = n-1; + } + } + } + T a_sym, b_sym; + a_sym = (last_m == m ? eps_symbol : a[last_m]); + b_sym = (last_n == n ? 
eps_symbol : b[last_n]); + output->push_back(std::make_pair(a_sym, b_sym)); + m = last_m; + n = last_n; + } + ReverseVector(output); + return e[M][N]; +} + + +} // end namespace kaldi + +#endif // KALDI_UTIL_EDIT_DISTANCE_INL_H_ diff --git a/speechx/speechx/kaldi/util/edit-distance.h b/speechx/speechx/kaldi/util/edit-distance.h new file mode 100644 index 00000000..5eac4aea --- /dev/null +++ b/speechx/speechx/kaldi/util/edit-distance.h @@ -0,0 +1,64 @@ +// util/edit-distance.h + +// Copyright 2009-2011 Microsoft Corporation; Haihua Xu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_UTIL_EDIT_DISTANCE_H_ +#define KALDI_UTIL_EDIT_DISTANCE_H_ +#include +#include +#include +#include +#include +#include +#include "util/edit-distance-inl.h" +#include "base/kaldi-types.h" + +namespace kaldi { + +// Compute the edit-distance between two strings. +template +int32 LevenshteinEditDistance(const std::vector &a, + const std::vector &b); + + +// edit distance calculation with conventional method. +// note: noise word must be filtered out from the hypothesis and +// reference sequence +// before the following procedure conducted. 
+template +int32 LevenshteinEditDistance(const std::vector &ref, + const std::vector &hyp, + int32 *ins, int32 *del, int32 *sub); + +// This version of the edit-distance computation outputs the alignment +// between the two. This is a vector of pairs of (symbol a, symbol b). +// The epsilon symbol (eps_symbol) must not occur in sequences a or b. +// Where one aligned to no symbol in the other (insertion or deletion), +// epsilon will be the corresponding member of the pair. +// It returns the edit-distance between the two strings. + +template +int32 LevenshteinAlignment(const std::vector &a, + const std::vector &b, + T eps_symbol, + std::vector > *output); + +} // end namespace kaldi + +#endif // KALDI_UTIL_EDIT_DISTANCE_H_ diff --git a/speechx/speechx/kaldi/util/hash-list-inl.h b/speechx/speechx/kaldi/util/hash-list-inl.h new file mode 100644 index 00000000..da6165af --- /dev/null +++ b/speechx/speechx/kaldi/util/hash-list-inl.h @@ -0,0 +1,194 @@ +// util/hash-list-inl.h + +// Copyright 2009-2011 Microsoft Corporation +// 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_UTIL_HASH_LIST_INL_H_ +#define KALDI_UTIL_HASH_LIST_INL_H_ + +// Do not include this file directly. 
+// It is included by fast-hash.h
+
+
+namespace kaldi {
+
+template<class I, class T> HashList<I, T>::HashList() {
+  list_head_ = NULL;
+  bucket_list_tail_ = static_cast<size_t>(-1);  // invalid.
+  hash_size_ = 0;
+  freed_head_ = NULL;
+}
+
+template<class I, class T> void HashList<I, T>::SetSize(size_t size) {
+  hash_size_ = size;
+  KALDI_ASSERT(list_head_ == NULL &&
+               bucket_list_tail_ == static_cast<size_t>(-1));  // make sure empty.
+  if (size > buckets_.size())
+    buckets_.resize(size, HashBucket(0, NULL));
+}
+
+template<class I, class T>
+typename HashList<I, T>::Elem* HashList<I, T>::Clear() {
+  // Clears the hashtable and gives ownership of the currently contained list
+  // to the user.
+  for (size_t cur_bucket = bucket_list_tail_;
+       cur_bucket != static_cast<size_t>(-1);
+       cur_bucket = buckets_[cur_bucket].prev_bucket) {
+    buckets_[cur_bucket].last_elem = NULL;  // this is how we indicate "empty".
+  }
+  bucket_list_tail_ = static_cast<size_t>(-1);
+  Elem *ans = list_head_;
+  list_head_ = NULL;
+  return ans;
+}
+
+template<class I, class T>
+const typename HashList<I, T>::Elem* HashList<I, T>::GetList() const {
+  return list_head_;
+}
+
+template<class I, class T>
+inline void HashList<I, T>::Delete(Elem *e) {
+  e->tail = freed_head_;
+  freed_head_ = e;
+}
+
+template<class I, class T>
+inline typename HashList<I, T>::Elem* HashList<I, T>::Find(I key) {
+  size_t index = (static_cast<size_t>(key) % hash_size_);
+  HashBucket &bucket = buckets_[index];
+  if (bucket.last_elem == NULL) {
+    return NULL;  // empty bucket.
+  } else {
+    Elem *head = (bucket.prev_bucket == static_cast<size_t>(-1) ?
+                  list_head_ :
+                  buckets_[bucket.prev_bucket].last_elem->tail),
+        *tail = bucket.last_elem->tail;
+    for (Elem *e = head; e != tail; e = e->tail)
+      if (e->key == key) return e;
+    return NULL;  // Not found.
+  }
+}
+
+template<class I, class T>
+inline typename HashList<I, T>::Elem* HashList<I, T>::New() {
+  if (freed_head_) {
+    Elem *ans = freed_head_;
+    freed_head_ = freed_head_->tail;
+    return ans;
+  } else {
+    Elem *tmp = new Elem[allocate_block_size_];
+    for (size_t i = 0; i+1 < allocate_block_size_; i++)
+      tmp[i].tail = tmp+i+1;
+    tmp[allocate_block_size_-1].tail = NULL;
+    freed_head_ = tmp;
+    allocated_.push_back(tmp);
+    return this->New();
+  }
+}
+
+template<class I, class T>
+HashList<I, T>::~HashList() {
+  // First test whether we had any memory leak within the
+  // HashList, i.e. things for which the user did not call Delete().
+  size_t num_in_list = 0, num_allocated = 0;
+  for (Elem *e = freed_head_; e != NULL; e = e->tail)
+    num_in_list++;
+  for (size_t i = 0; i < allocated_.size(); i++) {
+    num_allocated += allocate_block_size_;
+    delete[] allocated_[i];
+  }
+  if (num_in_list != num_allocated) {
+    KALDI_WARN << "Possible memory leak: " << num_in_list
+               << " != " << num_allocated
+               << ": you might have forgotten to call Delete on "
+               << "some Elems";
+  }
+}
+
+template<class I, class T>
+inline typename HashList<I, T>::Elem* HashList<I, T>::Insert(I key, T val) {
+  size_t index = (static_cast<size_t>(key) % hash_size_);
+  HashBucket &bucket = buckets_[index];
+  // Check the element is existing or not.
+  if (bucket.last_elem != NULL) {
+    Elem *head = (bucket.prev_bucket == static_cast<size_t>(-1) ?
+                  list_head_ :
+                  buckets_[bucket.prev_bucket].last_elem->tail),
+        *tail = bucket.last_elem->tail;
+    for (Elem *e = head; e != tail; e = e->tail)
+      if (e->key == key) return e;
+  }
+
+  // This is a new element. Insert it.
+  Elem *elem = New();
+  elem->key = key;
+  elem->val = val;
+  if (bucket.last_elem == NULL) {  // Unoccupied bucket.  Insert at
+    // head of bucket list (which is tail of regular list, they go in
+    // opposite directions).
+    if (bucket_list_tail_ == static_cast<size_t>(-1)) {
+      // list was empty so this is the first elem.
+      KALDI_ASSERT(list_head_ == NULL);
+      list_head_ = elem;
+    } else {
+      // link in to the chain of Elems
+      buckets_[bucket_list_tail_].last_elem->tail = elem;
+    }
+    elem->tail = NULL;
+    bucket.last_elem = elem;
+    bucket.prev_bucket = bucket_list_tail_;
+    bucket_list_tail_ = index;
+  } else {
+    // Already-occupied bucket.  Insert at tail of list of elements within
+    // the bucket.
+    elem->tail = bucket.last_elem->tail;
+    bucket.last_elem->tail = elem;
+    bucket.last_elem = elem;
+  }
+  return elem;
+}
+
+template<class I, class T>
+void HashList<I, T>::InsertMore(I key, T val) {
+  size_t index = (static_cast<size_t>(key) % hash_size_);
+  HashBucket &bucket = buckets_[index];
+  Elem *elem = New();
+  elem->key = key;
+  elem->val = val;
+
+  KALDI_ASSERT(bucket.last_elem != NULL);  // assume one element is already here
+  if (bucket.last_elem->key == key) {  // standard behavior: add as last element
+    elem->tail = bucket.last_elem->tail;
+    bucket.last_elem->tail = elem;
+    bucket.last_elem = elem;
+    return;
+  }
+  Elem *e = (bucket.prev_bucket == static_cast<size_t>(-1) ?
+             list_head_ : buckets_[bucket.prev_bucket].last_elem->tail);
+  // find place to insert in linked list
+  while (e != bucket.last_elem->tail && e->key != key) e = e->tail;
+  KALDI_ASSERT(e->key == key);  // not found? - should not happen
+  elem->tail = e->tail;
+  e->tail = elem;
+}
+
+
+}  // end namespace kaldi
+
+#endif  // KALDI_UTIL_HASH_LIST_INL_H_
diff --git a/speechx/speechx/kaldi/util/hash-list.h b/speechx/speechx/kaldi/util/hash-list.h
new file mode 100644
index 00000000..9ae0043f
--- /dev/null
+++ b/speechx/speechx/kaldi/util/hash-list.h
@@ -0,0 +1,147 @@
+// util/hash-list.h
+
+// Copyright 2009-2011   Microsoft Corporation
+//                2013   Johns Hopkins University (author: Daniel Povey)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+
+#ifndef KALDI_UTIL_HASH_LIST_H_
+#define KALDI_UTIL_HASH_LIST_H_
+#include <vector>
+#include <set>
+#include <algorithm>
+#include <limits>
+#include <cassert>
+#include "util/stl-utils.h"
+
+
+/* This header provides utilities for a structure that's used in a decoder (but
+   is quite generic in nature so we implement and test it separately).
+   Basically it's a singly-linked list, but implemented in such a way that we
+   can quickly search for elements in the list.  We give it a slightly richer
+   interface than just a hash and a list.  The idea is that we want to separate
+   the hash part and the list part: basically, in the decoder, we want to have a
+   single hash for the current frame and the next frame, because by the time we
+   need to access the hash for the next frame we no longer need the hash for the
+   previous frame.  So we have an operation that clears the hash but leaves the
+   list structure intact.  We also control memory management inside this object,
+   to avoid repeated new's/deletes.
+
+   See hash-list-test.cc for an example of how to use this object.
+*/
+
+
+namespace kaldi {
+
+template<class I, class T> class HashList {
+ public:
+  struct Elem {
+    I key;
+    T val;
+    Elem *tail;
+  };
+
+  /// Constructor takes no arguments.
+  /// Call SetSize to inform it of the likely size.
+  HashList();
+
+  /// Clears the hash and gives the head of the current list to the user;
+  /// ownership is transferred to the user (the user must call Delete()
+  /// for each element in the list, at his/her leisure).
+  Elem *Clear();
+
+  /// Gives the head of the current list to the user.  Ownership retained in the
+  /// class.  Caution: in December 2013 the return type was changed to const
+  /// Elem* and this function was made const.  You may need to change some types
+  /// of local Elem* variables to const if this produces compilation errors.
+  const Elem *GetList() const;
+
+  /// Think of this like delete().  It is to be called for each Elem in turn
+  /// after you "obtained ownership" by doing Clear().  This is not the opposite
+  /// of. Insert, it is the opposite of New.  It's really a memory operation.
+  inline void Delete(Elem *e);
+
+  /// This should probably not be needed to be called directly by the user.
+  /// Think of it as opposite
+  /// to Delete();
+  inline Elem *New();
+
+  /// Find tries to find this element in the current list using the hashtable.
+  /// It returns NULL if not present.  The Elem it returns is not owned by the
+  /// user, it is part of the internal list owned by this object, but the user
+  /// is free to modify the "val" element.
+  inline Elem *Find(I key);
+
+  /// Insert inserts a new element into the hashtable/stored list.
+  /// Because element keys in a hashtable are unique, this operation checks
+  /// whether each inserted element has a key equivalent to the one of an
+  /// element already in the hashtable.  If so, the element is not inserted,
+  /// returning an pointer to this existing element.
+  inline Elem *Insert(I key, T val);
+
+  /// Insert inserts another element with same key into the hashtable/
+  /// stored list.
+  /// By calling this, the user asserts that one element with that key is
+  /// already present.
+  /// We insert it that way, that all elements with the same key
+  /// follow each other.
+  /// Find() will return the first one of the elements with the same key.
+  inline void InsertMore(I key, T val);
+
+  /// SetSize tells the object how many hash buckets to allocate (should
+  /// typically be at least twice the number of objects we expect to go in the
+  /// structure, for fastest performance).  It must be called while the hash
+  /// is empty (e.g. after Clear() or after initializing the object, but before
+  /// adding anything to the hash.
+  void SetSize(size_t sz);
+
+  /// Returns current number of hash buckets.
+  inline size_t Size() { return hash_size_; }
+
+  ~HashList();
+ private:
+
+  struct HashBucket {
+    size_t prev_bucket;  // index to next bucket (-1 if list tail).  Note:
+    // list of buckets goes in opposite direction to list of Elems.
+    Elem *last_elem;  // pointer to last element in this bucket (NULL if empty)
+    inline HashBucket(size_t i, Elem *e): prev_bucket(i), last_elem(e) {}
+  };
+
+  Elem *list_head_;  // head of currently stored list.
+  size_t bucket_list_tail_;  // tail of list of active hash buckets.
+
+  size_t hash_size_;  // number of hash buckets.
+
+  std::vector<HashBucket> buckets_;
+
+  Elem *freed_head_;  // head of list of currently freed elements. [ready for
+                      // allocation]
+
+  std::vector<Elem*> allocated_;  // list of allocated blocks.
+
+  static const size_t allocate_block_size_ = 1024;  // Number of Elements to
+  // allocate in one block.  Must be largish so storing allocated_ doesn't
+  // become a problem.
+};
+
+
+}  // end namespace kaldi
+
+#include "util/hash-list-inl.h"
+
+#endif  // KALDI_UTIL_HASH_LIST_H_
diff --git a/speechx/speechx/kaldi/util/kaldi-cygwin-io-inl.h b/speechx/speechx/kaldi/util/kaldi-cygwin-io-inl.h
new file mode 100644
index 00000000..8a3cd91a
--- /dev/null
+++ b/speechx/speechx/kaldi/util/kaldi-cygwin-io-inl.h
@@ -0,0 +1,129 @@
+// util/kaldi-cygwin-io-inl.h
+
+// Copyright 2015  Smart Action Company LLC (author: Kirill Katsnelson)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+//  http://www.apache.org/licenses/LICENSE-2.0
+
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+#ifndef KALDI_UTIL_KALDI_CYGWIN_IO_INL_H_
+#define KALDI_UTIL_KALDI_CYGWIN_IO_INL_H_
+
+#ifndef _MSC_VER
+#error This is a Windows-compatibility file. Something went wery wrong.
+#endif
+
+#include <string>
+
+// This file is included only into kaldi-io.cc, and only if
+// KALDI_CYGWIN_COMPAT is enabled.
+//
+// The routines map unix-ey paths passed to Windows programs from shell
+// scripts in egs. Since shell scripts run under cygwin, they use cygwin's
+// own mount table and a mapping to the file system. It is quite possible to
+// create quite an intricate mapping that only own cygwin API would be able
+// to untangle. Unfortunately, the API to map between filenames is not
+// available to non-cygwin programs. Running cygpath for every file operation
+// would as well be cumbersome.
+// So this is only a simplistic path resolution,
+// assuming that the default cygwin prefix /cygdrive is used, and that all
+// resolved unix-style full paths end up prefixed with /cygdrive. This is
+// quite a sensible approach. We'll also try to map /dev/null and /tmp/**,
+// die on all other /dev/** and warn about all other rooted paths.
+
+namespace kaldi {
+
+static bool prefixp(const std::string& pfx, const std::string& str) {
+  return pfx.length() <= str.length() &&
+    std::equal(pfx.begin(), pfx.end(), str.begin());
+}
+
+static std::string cygprefix("/cygdrive/");
+
+static std::string MapCygwinPathNoTmp(const std::string &filename) {
+  // UNC(?), relative, native Windows and empty paths are ok already.
+  if (prefixp("//", filename) || !prefixp("/", filename))
+    return filename;
+
+  // /dev/...
+  if (filename == "/dev/null")
+    return "\\\\.\\nul";
+  if (prefixp("/dev/", filename)) {
+    KALDI_ERR << "Unable to resolve path '" << filename
+              << "' - only have /dev/null here.";
+    return "\\\\.\\invalid";
+  }
+
+  // /cygdrive/?[/....]
+  int preflen = cygprefix.size();
+  if (prefixp(cygprefix, filename)
+      && filename.size() >= preflen + 1 && isalpha(filename[preflen])
+      && (filename.size() == preflen + 1 || filename[preflen + 1] == '/')) {
+    return std::string() + filename[preflen] + ':' +
+        (filename.size() > preflen + 1 ? filename.substr(preflen + 1) : "/");
+  }
+
+  KALDI_WARN << "Unable to resolve path '" << filename
+             << "' - cannot map unix prefix. "
+             << "Will go on, but breakage will likely ensue.";
+  return filename;
+}
+
+// extern for unit testing.
+std::string MapCygwinPath(const std::string &filename) {
+  // /tmp[/....]
+  if (filename != "/tmp" && !prefixp("/tmp/", filename)) {
+    return MapCygwinPathNoTmp(filename);
+  }
+  char *tmpdir = std::getenv("TMP");
+  if (tmpdir == nullptr)
+    tmpdir = std::getenv("TEMP");
+  if (tmpdir == nullptr) {
+    KALDI_ERR << "Unable to resolve path '" << filename
+              << "' - unable to find temporary directory. Set TMP.";
+    return filename;
+  }
+  // Map the value of tmpdir again, as cygwin environment actually may contain
+  // unix-style paths.
+  return MapCygwinPathNoTmp(std::string(tmpdir) + filename.substr(4));
+}
+
+// A popen implementation that passes the command line through cygwin
+// bash.exe. This is necessary since some piped commands are cygwin links
+// (e. g. fgrep is a soft link to grep), and some are #!-files, such as
+// gunzip which is a shell script that invokes gzip, or kaldi's own run.pl
+// which is a perl script.
+//
+// _popen uses cmd.exe or whatever shell is specified via the COMSPEC
+// variable. Unfortunately, it adds a hardcoded " /c " to it, so we cannot
+// just substitute the environment variable COMSPEC to point to bash.exe.
+// Instead, quote the command and pass it to bash via its -c switch.
+static FILE *CygwinCompatPopen(const char* command, const char* mode) {
+  // To speed up command launch marginally, optionally accept full path
+  // to bash.exe. This will not work if the path contains spaces, but
+  // no sane person would install cygwin into a space-ridden path.
+  const char* bash_exe = std::getenv("BASH_EXE");
+  std::string qcmd(bash_exe != nullptr ? bash_exe : "bash.exe");
+  qcmd += " -c \"";
+  for (; *command; ++command) {
+    if (*command == '\"')
+      qcmd += '\"';
+    qcmd += *command;
+  }
+  qcmd += '\"';
+
+  return _popen(qcmd.c_str(), mode);
+}
+
+}  // namespace kaldi
+
+#endif  // KALDI_UTIL_KALDI_CYGWIN_IO_INL_H_
diff --git a/speechx/speechx/kaldi/util/kaldi-holder-inl.h b/speechx/speechx/kaldi/util/kaldi-holder-inl.h
new file mode 100644
index 00000000..134cdd93
--- /dev/null
+++ b/speechx/speechx/kaldi/util/kaldi-holder-inl.h
@@ -0,0 +1,922 @@
+// util/kaldi-holder-inl.h
+
+// Copyright 2009-2011     Microsoft Corporation
+//                2016     Xiaohui Zhang
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+
+#ifndef KALDI_UTIL_KALDI_HOLDER_INL_H_
+#define KALDI_UTIL_KALDI_HOLDER_INL_H_
+
+#include <algorithm>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "base/kaldi-utils.h"
+#include "util/kaldi-io.h"
+#include "util/text-utils.h"
+#include "matrix/kaldi-matrix.h"
+
+namespace kaldi {
+
+/// \addtogroup holders
+/// @{
+
+
+// KaldiObjectHolder is valid only for Kaldi objects with
+// copy constructors, default constructors, and "normal"
+// Kaldi Write and Read functions.  E.g. it works for
+// Matrix and Vector.
+template class KaldiObjectHolder { + public: + typedef KaldiType T; + + KaldiObjectHolder(): t_(NULL) { } + + static bool Write(std::ostream &os, bool binary, const T &t) { + InitKaldiOutputStream(os, binary); // Puts binary header if binary mode. + try { + t.Write(os, binary); + return os.good(); + } catch(const std::exception &e) { + KALDI_WARN << "Exception caught writing Table object. " << e.what(); + return false; // Write failure. + } + } + + void Clear() { + if (t_) { + delete t_; + t_ = NULL; + } + } + + // Reads into the holder. + bool Read(std::istream &is) { + delete t_; + t_ = new T; + // Don't want any existing state to complicate the read function: get new + // object. + bool is_binary; + if (!InitKaldiInputStream(is, &is_binary)) { + KALDI_WARN << "Reading Table object, failed reading binary header\n"; + return false; + } + try { + t_->Read(is, is_binary); + return true; + } catch(const std::exception &e) { + KALDI_WARN << "Exception caught reading Table object. " << e.what(); + delete t_; + t_ = NULL; + return false; + } + } + + // Kaldi objects always have the stream open in binary mode for + // reading. + static bool IsReadInBinary() { return true; } + + T &Value() { + // code error if !t_. + if (!t_) KALDI_ERR << "KaldiObjectHolder::Value() called wrongly."; + return *t_; + } + + void Swap(KaldiObjectHolder *other) { + // the t_ values are pointers so this is a shallow swap. + std::swap(t_, other->t_); + } + + bool ExtractRange(const KaldiObjectHolder &other, + const std::string &range) { + KALDI_ASSERT(other.t_ != NULL); + delete t_; + t_ = new T; + // this call will fail for most object types. + return ExtractObjectRange(*(other.t_), range, t_); + } + + ~KaldiObjectHolder() { delete t_; } + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(KaldiObjectHolder); + T *t_; +}; + + +// BasicHolder is valid for float, double, bool, and integer +// types. 
+// There will be a compile time error otherwise, because
+// we make sure that the {Write, Read}BasicType functions do not
+// get instantiated for other types.
+
+template<class BasicType> class BasicHolder {
+ public:
+  typedef BasicType T;
+
+  BasicHolder(): t_(static_cast<T>(-1)) { }
+
+  static bool Write(std::ostream &os, bool binary, const T &t) {
+    InitKaldiOutputStream(os, binary);  // Puts binary header if binary mode.
+    try {
+      WriteBasicType(os, binary, t);
+      if (!binary) os << '\n';  // Makes output format more readable and
+      // easier to manipulate.
+      return os.good();
+    } catch(const std::exception &e) {
+      KALDI_WARN << "Exception caught writing Table object. " << e.what();
+      return false;  // Write failure.
+    }
+  }
+
+  void Clear() { }
+
+  // Reads into the holder.
+  bool Read(std::istream &is) {
+    bool is_binary;
+    if (!InitKaldiInputStream(is, &is_binary)) {
+      KALDI_WARN << "Reading Table object [integer type], failed reading binary"
+          " header\n";
+      return false;
+    }
+    try {
+      int c;
+      if (!is_binary) {  // This is to catch errors, the class would work
+        // without it..
+        // Eat up any whitespace and make sure it's not newline.
+        while (isspace((c = is.peek())) && c != static_cast<int>('\n')) {
+          is.get();
+        }
+        if (is.peek() == '\n') {
+          KALDI_WARN << "Found newline but expected basic type.";
+          return false;  // This is just to catch a more-
+          // likely-than average type of error (empty line before the token),
+          // since ReadBasicType will eat it up.
+        }
+      }
+
+      ReadBasicType(is, is_binary, &t_);
+
+      if (!is_binary) {  // This is to catch errors, the class would work
+        // without it..
+        // make sure there is a newline.
+        while (isspace((c = is.peek())) && c != static_cast<int>('\n')) {
+          is.get();
+        }
+        if (is.peek() != '\n') {
+          KALDI_WARN << "BasicHolder::Read, expected newline, got "
+                     << CharToString(is.peek()) << ", position " << is.tellg();
+          return false;
+        }
+        is.get();  // Consume the newline.
+      }
+      return true;
+    } catch(const std::exception &e) {
+      KALDI_WARN << "Exception caught reading Table object. " << e.what();
+      return false;
+    }
+  }
+
+  // Objects read/written with the Kaldi I/O functions always have the stream
+  // open in binary mode for reading.
+  static bool IsReadInBinary() { return true; }
+
+  T &Value() {
+    return t_;
+  }
+
+  void Swap(BasicHolder<T> *other) {
+    std::swap(t_, other->t_);
+  }
+
+  bool ExtractRange(const BasicHolder<T> &other, const std::string &range) {
+    KALDI_ERR << "ExtractRange is not defined for this type of holder.";
+    return false;
+  }
+
+  ~BasicHolder() { }
+ private:
+  KALDI_DISALLOW_COPY_AND_ASSIGN(BasicHolder);
+
+  T t_;
+};
+
+
+/// A Holder for a vector of basic types, e.g.
+/// std::vector<int32>, std::vector<float>, and so on.
+/// Note: a basic type is defined as a type for which ReadBasicType
+/// and WriteBasicType are implemented, i.e. integer and floating
+/// types, and bool.
+template<class BasicType> class BasicVectorHolder {
+ public:
+  typedef std::vector<BasicType> T;
+
+  BasicVectorHolder() { }
+
+  static bool Write(std::ostream &os, bool binary, const T &t) {
+    InitKaldiOutputStream(os, binary);  // Puts binary header if binary mode.
+    try {
+      if (binary) {  // need to write the size, in binary mode.
+        KALDI_ASSERT(static_cast<size_t>(static_cast<int32>(t.size())) ==
+                     t.size());
+        // Or this Write routine cannot handle such a large vector.
+        // use int32 because it's fixed size regardless of compilation.
+        // change to int64 (plus in Read function) if this becomes a problem.
+        WriteBasicType(os, binary, static_cast<int32>(t.size()));
+        for (typename std::vector<BasicType>::const_iterator iter = t.begin();
+             iter != t.end(); ++iter)
+          WriteBasicType(os, binary, *iter);
+
+      } else {
+        for (typename std::vector<BasicType>::const_iterator iter = t.begin();
+             iter != t.end(); ++iter)
+          WriteBasicType(os, binary, *iter);
+        os << '\n';  // Makes output format more readable and
+        // easier to manipulate.  In text mode, this function writes something
+        // like "1 2 3\n".
+      }
+      return os.good();
+    } catch(const std::exception &e) {
+      KALDI_WARN << "Exception caught writing Table object (BasicVector). "
+                 << e.what();
+      return false;  // Write failure.
+    }
+  }
+
+  void Clear() { t_.clear(); }
+
+  // Reads into the holder.
+  bool Read(std::istream &is) {
+    t_.clear();
+    bool is_binary;
+    if (!InitKaldiInputStream(is, &is_binary)) {
+      KALDI_WARN << "Reading Table object [integer type], failed reading binary"
+          " header\n";
+      return false;
+    }
+    if (!is_binary) {
+      // In text mode, we terminate with newline.
+      std::string line;
+      getline(is, line);  // this will discard the \n, if present.
+      if (is.fail()) {
+        KALDI_WARN << "BasicVectorHolder::Read, error reading line " <<
+            (is.eof() ? "[eof]" : "");
+        return false;  // probably eof.  fail in any case.
+      }
+      std::istringstream line_is(line);
+      try {
+        while (1) {
+          line_is >> std::ws;  // eat up whitespace.
+          if (line_is.eof()) break;
+          BasicType bt;
+          ReadBasicType(line_is, false, &bt);
+          t_.push_back(bt);
+        }
+        return true;
+      } catch(const std::exception &e) {
+        KALDI_WARN << "BasicVectorHolder::Read, could not interpret line: "
+                   << "'" << line << "'" << "\n" << e.what();
+        return false;
+      }
+    } else {  // binary mode.
+      size_t filepos = is.tellg();
+      try {
+        int32 size;
+        ReadBasicType(is, true, &size);
+        t_.resize(size);
+        for (typename std::vector<BasicType>::iterator iter = t_.begin();
+             iter != t_.end();
+             ++iter) {
+          ReadBasicType(is, true, &(*iter));
+        }
+        return true;
+      } catch(...) {
+        KALDI_WARN << "BasicVectorHolder::Read, read error or unexpected data"
+            " at archive entry beginning at file position " << filepos;
+        return false;
+      }
+    }
+  }
+
+  // Objects read/written with the Kaldi I/O functions always have the stream
+  // open in binary mode for reading.
+ static bool IsReadInBinary() { return true; } + + T &Value() { return t_; } + + void Swap(BasicVectorHolder *other) { + t_.swap(other->t_); + } + + bool ExtractRange(const BasicVectorHolder &other, + const std::string &range) { + KALDI_ERR << "ExtractRange is not defined for this type of holder."; + return false; + } + + ~BasicVectorHolder() { } + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(BasicVectorHolder); + T t_; +}; + + +/// BasicVectorVectorHolder is a Holder for a vector of vector of +/// a basic type, e.g. std::vector >. +/// Note: a basic type is defined as a type for which ReadBasicType +/// and WriteBasicType are implemented, i.e. integer and floating +/// types, and bool. +template class BasicVectorVectorHolder { + public: + typedef std::vector > T; + + BasicVectorVectorHolder() { } + + static bool Write(std::ostream &os, bool binary, const T &t) { + InitKaldiOutputStream(os, binary); // Puts binary header if binary mode. + try { + if (binary) { // need to write the size, in binary mode. + KALDI_ASSERT(static_cast(static_cast(t.size())) == + t.size()); + // Or this Write routine cannot handle such a large vector. + // use int32 because it's fixed size regardless of compilation. + // change to int64 (plus in Read function) if this becomes a problem. + WriteBasicType(os, binary, static_cast(t.size())); + for (typename std::vector >::const_iterator + iter = t.begin(); + iter != t.end(); ++iter) { + KALDI_ASSERT(static_cast(static_cast(iter->size())) + == iter->size()); + WriteBasicType(os, binary, static_cast(iter->size())); + for (typename std::vector::const_iterator + iter2 = iter->begin(); + iter2 != iter->end(); ++iter2) { + WriteBasicType(os, binary, *iter2); + } + } + } else { // text mode... 
+ // In text mode, we write out something like (for integers): + // "1 2 3 ; 4 5 ; 6 ; ; 7 8 9 ;\n" + // where the semicolon is a terminator, not a separator + // (a separator would cause ambiguity between an + // empty list, and a list containing a single empty list). + for (typename std::vector >::const_iterator + iter = t.begin(); + iter != t.end(); + ++iter) { + for (typename std::vector::const_iterator + iter2 = iter->begin(); + iter2 != iter->end(); ++iter2) + WriteBasicType(os, binary, *iter2); + os << "; "; + } + os << '\n'; + } + return os.good(); + } catch(const std::exception &e) { + KALDI_WARN << "Exception caught writing Table object. " << e.what(); + return false; // Write failure. + } + } + + void Clear() { t_.clear(); } + + // Reads into the holder. + bool Read(std::istream &is) { + t_.clear(); + bool is_binary; + if (!InitKaldiInputStream(is, &is_binary)) { + KALDI_WARN << "Failed reading binary header\n"; + return false; + } + if (!is_binary) { + // In text mode, we terminate with newline. + try { // catching errors from ReadBasicType.. + std::vector v; // temporary vector + while (1) { + int i = is.peek(); + if (i == -1) { + KALDI_WARN << "Unexpected EOF"; + return false; + } else if (static_cast(i) == '\n') { + if (!v.empty()) { + KALDI_WARN << "No semicolon before newline (wrong format)"; + return false; + } else { + is.get(); + return true; + } + } else if (std::isspace(i)) { + is.get(); + } else if (static_cast(i) == ';') { + t_.push_back(v); + v.clear(); + is.get(); + } else { // some object we want to read... + BasicType b; + ReadBasicType(is, false, &b); // throws on error. + v.push_back(b); + } + } + } catch(const std::exception &e) { + KALDI_WARN << "BasicVectorVectorHolder::Read, read error. " << e.what(); + return false; + } + } else { // binary mode. 
+ size_t filepos = is.tellg(); + try { + int32 size; + ReadBasicType(is, true, &size); + t_.resize(size); + for (typename std::vector >::iterator + iter = t_.begin(); + iter != t_.end(); + ++iter) { + int32 size2; + ReadBasicType(is, true, &size2); + iter->resize(size2); + for (typename std::vector::iterator iter2 = iter->begin(); + iter2 != iter->end(); + ++iter2) + ReadBasicType(is, true, &(*iter2)); + } + return true; + } catch(...) { + KALDI_WARN << "Read error or unexpected data at archive entry beginning" + " at file position " << filepos; + return false; + } + } + } + + // Objects read/written with the Kaldi I/O functions always have the stream + // open in binary mode for reading. + static bool IsReadInBinary() { return true; } + + T &Value() { return t_; } + + void Swap(BasicVectorVectorHolder *other) { + t_.swap(other->t_); + } + + bool ExtractRange(BasicVectorVectorHolder &other, + const std::string &range) { + KALDI_ERR << "ExtractRange is not defined for this type of holder."; + return false; + } + + ~BasicVectorVectorHolder() { } + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(BasicVectorVectorHolder); + T t_; +}; + + +/// BasicPairVectorHolder is a Holder for a vector of pairs of +/// a basic type, e.g. std::vector >. +/// Note: a basic type is defined as a type for which ReadBasicType +/// and WriteBasicType are implemented, i.e. integer and floating +/// types, and bool. +template class BasicPairVectorHolder { + public: + typedef std::vector > T; + + BasicPairVectorHolder() { } + + static bool Write(std::ostream &os, bool binary, const T &t) { + InitKaldiOutputStream(os, binary); // Puts binary header if binary mode. + try { + if (binary) { // need to write the size, in binary mode. + KALDI_ASSERT(static_cast(static_cast(t.size())) == + t.size()); + // Or this Write routine cannot handle such a large vector. + // use int32 because it's fixed size regardless of compilation. + // change to int64 (plus in Read function) if this becomes a problem. 
+ WriteBasicType(os, binary, static_cast(t.size())); + for (typename T::const_iterator iter = t.begin(); + iter != t.end(); ++iter) { + WriteBasicType(os, binary, iter->first); + WriteBasicType(os, binary, iter->second); + } + } else { // text mode... + // In text mode, we write out something like (for integers): + // "1 2 ; 4 5 ; 6 7 ; 8 9 \n" + // where the semicolon is a separator, not a terminator. + for (typename T::const_iterator iter = t.begin(); + iter != t.end();) { + WriteBasicType(os, binary, iter->first); + WriteBasicType(os, binary, iter->second); + ++iter; + if (iter != t.end()) + os << "; "; + } + os << '\n'; + } + return os.good(); + } catch(const std::exception &e) { + KALDI_WARN << "Exception caught writing Table object. " << e.what(); + return false; // Write failure. + } + } + + void Clear() { t_.clear(); } + + // Reads into the holder. + bool Read(std::istream &is) { + t_.clear(); + bool is_binary; + if (!InitKaldiInputStream(is, &is_binary)) { + KALDI_WARN << "Reading Table object [integer type], failed reading binary" + " header\n"; + return false; + } + if (!is_binary) { + // In text mode, we terminate with newline. + try { // catching errors from ReadBasicType.. 
+ std::vector v; // temporary vector + while (1) { + int i = is.peek(); + if (i == -1) { + KALDI_WARN << "Unexpected EOF"; + return false; + } else if (static_cast(i) == '\n') { + if (t_.empty() && v.empty()) { + is.get(); + return true; + } else if (v.size() == 2) { + t_.push_back(std::make_pair(v[0], v[1])); + is.get(); + return true; + } else { + KALDI_WARN << "Unexpected newline, reading vector >; got " + << v.size() << " elements, expected 2."; + return false; + } + } else if (std::isspace(i)) { + is.get(); + } else if (static_cast(i) == ';') { + if (v.size() != 2) { + KALDI_WARN << "Wrong input format, reading vector >; got " + << v.size() << " elements, expected 2."; + return false; + } + t_.push_back(std::make_pair(v[0], v[1])); + v.clear(); + is.get(); + } else { // some object we want to read... + BasicType b; + ReadBasicType(is, false, &b); // throws on error. + v.push_back(b); + } + } + } catch(const std::exception &e) { + KALDI_WARN << "BasicPairVectorHolder::Read, read error. " << e.what(); + return false; + } + } else { // binary mode. + size_t filepos = is.tellg(); + try { + int32 size; + ReadBasicType(is, true, &size); + t_.resize(size); + for (typename T::iterator iter = t_.begin(); + iter != t_.end(); + ++iter) { + ReadBasicType(is, true, &(iter->first)); + ReadBasicType(is, true, &(iter->second)); + } + return true; + } catch(...) { + KALDI_WARN << "BasicVectorHolder::Read, read error or unexpected data" + " at archive entry beginning at file position " << filepos; + return false; + } + } + } + + // Objects read/written with the Kaldi I/O functions always have the stream + // open in binary mode for reading. 
+ static bool IsReadInBinary() { return true; } + + T &Value() { return t_; } + + void Swap(BasicPairVectorHolder *other) { + t_.swap(other->t_); + } + + bool ExtractRange(const BasicPairVectorHolder &other, + const std::string &range) { + KALDI_ERR << "ExtractRange is not defined for this type of holder."; + return false; + } + + ~BasicPairVectorHolder() { } + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(BasicPairVectorHolder); + T t_; +}; + + + + +// We define a Token as a nonempty, printable, whitespace-free std::string. +// The binary and text formats here are the same (newline-terminated) +// and as such we don't bother with the binary-mode headers. +class TokenHolder { + public: + typedef std::string T; + + TokenHolder() {} + + static bool Write(std::ostream &os, bool, const T &t) { // ignore binary-mode + KALDI_ASSERT(IsToken(t)); + os << t << '\n'; + return os.good(); + } + + void Clear() { t_.clear(); } + + // Reads into the holder. + bool Read(std::istream &is) { + is >> t_; + if (is.fail()) return false; + char c; + while (isspace(c = is.peek()) && c!= '\n') is.get(); + if (is.peek() != '\n') { + KALDI_WARN << "TokenHolder::Read, expected newline, got char " + << CharToString(is.peek()) + << ", at stream pos " << is.tellg(); + return false; + } + is.get(); // get '\n' + return true; + } + + + // Since this is fundamentally a text format, read in text mode (would work + // fine either way, but doing it this way will exercise more of the code). + static bool IsReadInBinary() { return false; } + + T &Value() { return t_; } + + ~TokenHolder() { } + + void Swap(TokenHolder *other) { + t_.swap(other->t_); + } + + bool ExtractRange(const TokenHolder &other, + const std::string &range) { + KALDI_ERR << "ExtractRange is not defined for this type of holder."; + return false; + } + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(TokenHolder); + T t_; +}; + +// A Token is a nonempty, whitespace-free std::string. 
+// Class TokenVectorHolder is a Holder class for vectors of these. +class TokenVectorHolder { + public: + typedef std::vector T; + + TokenVectorHolder() { } + + static bool Write(std::ostream &os, bool, const T &t) { // ignore binary-mode + for (std::vector::const_iterator iter = t.begin(); + iter != t.end(); + ++iter) { + KALDI_ASSERT(IsToken(*iter)); // make sure it's whitespace-free, + // printable and nonempty. + os << *iter << ' '; + } + os << '\n'; + return os.good(); + } + + void Clear() { t_.clear(); } + + + // Reads into the holder. + bool Read(std::istream &is) { + t_.clear(); + + // there is no binary/non-binary mode. + + std::string line; + getline(is, line); // this will discard the \n, if present. + if (is.fail()) { + KALDI_WARN << "BasicVectorHolder::Read, error reading line " << (is.eof() + ? "[eof]" : ""); + return false; // probably eof. fail in any case. + } + const char *white_chars = " \t\n\r\f\v"; + SplitStringToVector(line, white_chars, true, &t_); // true== omit + // empty strings e.g. between spaces. + return true; + } + + // Read in text format since it's basically a text-mode thing.. doesn't really + // matter, it would work either way since we ignore the extra '\r'. 
+ static bool IsReadInBinary() { return false; } + + T &Value() { return t_; } + + void Swap(TokenVectorHolder *other) { + t_.swap(other->t_); + } + + bool ExtractRange(const TokenVectorHolder &other, + const std::string &range) { + KALDI_ERR << "ExtractRange is not defined for this type of holder."; + return false; + } + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(TokenVectorHolder); + T t_; +}; + + +class HtkMatrixHolder { + public: + typedef std::pair, HtkHeader> T; + + HtkMatrixHolder() {} + + static bool Write(std::ostream &os, bool binary, const T &t) { + if (!binary) + KALDI_ERR << "Non-binary HTK-format write not supported."; + bool ans = WriteHtk(os, t.first, t.second); + if (!ans) + KALDI_WARN << "Error detected writing HTK-format matrix."; + return ans; + } + + void Clear() { t_.first.Resize(0, 0); } + + // Reads into the holder. + bool Read(std::istream &is) { + bool ans = ReadHtk(is, &t_.first, &t_.second); + if (!ans) { + KALDI_WARN << "Error detected reading HTK-format matrix."; + return false; + } + return ans; + } + + // HTK-format matrices only read in binary. + static bool IsReadInBinary() { return true; } + + T &Value() { return t_; } + + void Swap(HtkMatrixHolder *other) { + t_.first.Swap(&(other->t_.first)); + std::swap(t_.second, other->t_.second); + } + + bool ExtractRange(const HtkMatrixHolder &other, + const std::string &range) { + KALDI_ERR << "ExtractRange is not defined for this type of holder."; + return false; + } + // Default destructor. + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(HtkMatrixHolder); + T t_; +}; + +// SphinxMatrixHolder can be used to read and write feature files in +// CMU Sphinx format. 13-dimensional big-endian features are assumed. +// The ultimate reference is SphinxBase's source code (for example see +// feat_s2mfc_read() in src/libsphinxbase/feat/feat.c). 
+// We can't fully automate the detection of machine/feature file endianess +// mismatch here, because for this Sphinx relies on comparing the feature +// file's size with the number recorded in its header. We are working with +// streams, however(what happens if this is a Kaldi archive?). This should +// be no problem, because the usage help of Sphinx' "wave2feat" for example +// says that Sphinx features are always big endian. +// Note: the kFeatDim defaults to 13, see forward declaration in kaldi-holder.h +template class SphinxMatrixHolder { + public: + typedef Matrix T; + + SphinxMatrixHolder() {} + + void Clear() { feats_.Resize(0, 0); } + + // Writes Sphinx-format features + static bool Write(std::ostream &os, bool binary, const T &m) { + if (!binary) { + KALDI_WARN << "SphinxMatrixHolder can't write Sphinx features in text "; + return false; + } + + int32 size = m.NumRows() * m.NumCols(); + if (MachineIsLittleEndian()) + KALDI_SWAP4(size); + // write the header + os.write(reinterpret_cast (&size), sizeof(size)); + + for (MatrixIndexT i = 0; i < m.NumRows(); i++) { + std::vector tmp(m.NumCols()); + for (MatrixIndexT j = 0; j < m.NumCols(); j++) { + tmp[j] = static_cast(m(i, j)); + if (MachineIsLittleEndian()) + KALDI_SWAP4(tmp[j]); + } + os.write(reinterpret_cast(&(tmp[0])), + tmp.size() * 4); + } + return true; + } + + // Reads the features into a Kaldi Matrix + bool Read(std::istream &is) { + int32 nmfcc; + + is.read(reinterpret_cast (&nmfcc), sizeof(nmfcc)); + if (MachineIsLittleEndian()) + KALDI_SWAP4(nmfcc); + KALDI_VLOG(2) << "#feats: " << nmfcc; + int32 nfvec = nmfcc / kFeatDim; + if ((nmfcc % kFeatDim) != 0) { + KALDI_WARN << "Sphinx feature count is inconsistent with vector length "; + return false; + } + + feats_.Resize(nfvec, kFeatDim); + for (MatrixIndexT i = 0; i < feats_.NumRows(); i++) { + if (sizeof(BaseFloat) == sizeof(float32)) { + is.read(reinterpret_cast (feats_.RowData(i)), + kFeatDim * sizeof(float32)); + if (!is.good()) { + KALDI_WARN 
<< "Unexpected error/EOF while reading Sphinx features "; + return false; + } + if (MachineIsLittleEndian()) { + for (MatrixIndexT j = 0; j < kFeatDim; j++) + KALDI_SWAP4(feats_(i, j)); + } + } else { // KALDI_DOUBLEPRECISION=1 + float32 tmp[kFeatDim]; + is.read(reinterpret_cast (tmp), sizeof(tmp)); + if (!is.good()) { + KALDI_WARN << "Unexpected error/EOF while reading Sphinx features "; + return false; + } + for (MatrixIndexT j = 0; j < kFeatDim; j++) { + if (MachineIsLittleEndian()) + KALDI_SWAP4(tmp[j]); + feats_(i, j) = static_cast(tmp[j]); + } + } + } + + return true; + } + + // Only read in binary + static bool IsReadInBinary() { return true; } + + T &Value() { return feats_; } + + void Swap(SphinxMatrixHolder *other) { + feats_.Swap(&(other->feats_)); + } + + bool ExtractRange(const SphinxMatrixHolder &other, + const std::string &range) { + KALDI_ERR << "ExtractRange is not defined for this type of holder."; + return false; + } + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(SphinxMatrixHolder); + T feats_; +}; + + +/// @} end "addtogroup holders" + +} // end namespace kaldi + + + +#endif // KALDI_UTIL_KALDI_HOLDER_INL_H_ diff --git a/speechx/speechx/kaldi/util/kaldi-holder.cc b/speechx/speechx/kaldi/util/kaldi-holder.cc new file mode 100644 index 00000000..577679ef --- /dev/null +++ b/speechx/speechx/kaldi/util/kaldi-holder.cc @@ -0,0 +1,229 @@ +// util/kaldi-holder.cc + +// Copyright 2009-2011 Microsoft Corporation +// 2016 Xiaohui Zhang + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "util/kaldi-holder.h" +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + +// Parse matrix range specifier in form r1:r2,c1:c2 +// where any of those four numbers can be missing. In those +// cases, the missing number is set either to 0 (for r1 or c1) +// or the value of parameter rows -1 or columns -1 (which +// represent the dimensions of the original matrix) for missing +// r2 or c2, respectively. +// Examples of valid ranges: 0:39,: or :,:3 or :,5:10 +bool ParseMatrixRangeSpecifier(const std::string &range, + const int rows, const int cols, + std::vector *row_range, + std::vector *col_range) { + if (range.empty()) { + KALDI_ERR << "Empty range specifier."; + return false; + } + std::vector splits; + SplitStringToVector(range, ",", false, &splits); + if (!((splits.size() == 1 && !splits[0].empty()) || + (splits.size() == 2 && !splits[0].empty() && !splits[1].empty()))) { + KALDI_ERR << "Invalid range specifier for matrix: " << range; + return false; + } + + bool status = true; + + if (splits[0] != ":") + status = SplitStringToIntegers(splits[0], ":", false, row_range); + + if (splits.size() == 2 && splits[1] != ":") { + status = status && SplitStringToIntegers(splits[1], ":", false, col_range); + } + if (row_range->size() == 0) { + row_range->push_back(0); + row_range->push_back(rows - 1); + } + if (col_range->size() == 0) { + col_range->push_back(0); + col_range->push_back(cols - 1); + } + + // Length tolerance of 3 -- 2 to account for edge effects when + // 
frame-length is 25ms and frame-shift is 10ms, and 1 for rounding effects + // since segments are usually retained up to 2 decimal places. + int32 length_tolerance = 3; + if (!(status && row_range->size() == 2 && col_range->size() == 2 && + row_range->at(0) >= 0 && row_range->at(0) <= row_range->at(1) && + row_range->at(1) < rows + length_tolerance && + col_range->at(0) >=0 && + col_range->at(0) <= col_range->at(1) && col_range->at(1) < cols)) { + KALDI_ERR << "Invalid range specifier: " << range + << " for matrix of size " << rows + << "x" << cols; + return false; + } + + if (row_range->at(1) >= rows) + KALDI_WARN << "Row range " << row_range->at(0) << ":" << row_range->at(1) + << " goes beyond the number of rows of the " + << "matrix " << rows; + return status; +} + +bool ExtractObjectRange(const GeneralMatrix &input, const std::string &range, + GeneralMatrix *output) { + // We just inspect input's type and forward to the correct implementation + // if available. For kSparseMatrix, we do just fairly inefficient conversion + // to a full matrix. + Matrix output_mat; + if (input.Type() == kFullMatrix) { + const Matrix &in = input.GetFullMatrix(); + ExtractObjectRange(in, range, &output_mat); + } else if (input.Type() == kCompressedMatrix) { + const CompressedMatrix &in = input.GetCompressedMatrix(); + ExtractObjectRange(in, range, &output_mat); + } else { + KALDI_ASSERT(input.Type() == kSparseMatrix); + // NOTE: this is fairly inefficient, so if this happens to be bottleneck + // it should be re-implemented more efficiently. 
+ Matrix input_mat; + input.GetMatrix(&input_mat); + ExtractObjectRange(input_mat, range, &output_mat); + } + output->Clear(); + output->SwapFullMatrix(&output_mat); + return true; +} + +template +bool ExtractObjectRange(const CompressedMatrix &input, const std::string &range, + Matrix *output) { + std::vector row_range, col_range; + + if (!ParseMatrixRangeSpecifier(range, input.NumRows(), input.NumCols(), + &row_range, &col_range)) { + KALDI_ERR << "Could not parse range specifier \"" << range << "\"."; + } + + int32 row_size = std::min(row_range[1], input.NumRows() - 1) + - row_range[0] + 1, + col_size = col_range[1] - col_range[0] + 1; + + output->Resize(row_size, col_size, kUndefined); + input.CopyToMat(row_range[0], col_range[0], output); + return true; +} + +// template instantiation +template bool ExtractObjectRange(const CompressedMatrix &, const std::string &, + Matrix *); +template bool ExtractObjectRange(const CompressedMatrix &, const std::string &, + Matrix *); + +template +bool ExtractObjectRange(const Matrix &input, const std::string &range, + Matrix *output) { + std::vector row_range, col_range; + + if (!ParseMatrixRangeSpecifier(range, input.NumRows(), input.NumCols(), + &row_range, &col_range)) { + KALDI_ERR << "Could not parse range specifier \"" << range << "\"."; + } + + int32 row_size = std::min(row_range[1], input.NumRows() - 1) + - row_range[0] + 1, + col_size = col_range[1] - col_range[0] + 1; + output->Resize(row_size, col_size, kUndefined); + output->CopyFromMat(input.Range(row_range[0], row_size, + col_range[0], col_size)); + return true; +} + +// template instantiation +template bool ExtractObjectRange(const Matrix &, const std::string &, + Matrix *); +template bool ExtractObjectRange(const Matrix &, const std::string &, + Matrix *); + +template +bool ExtractObjectRange(const Vector &input, const std::string &range, + Vector *output) { + if (range.empty()) { + KALDI_ERR << "Empty range specifier."; + return false; + } + std::vector 
splits; + SplitStringToVector(range, ",", false, &splits); + if (!((splits.size() == 1 && !splits[0].empty()))) { + KALDI_ERR << "Invalid range specifier for vector: " << range; + return false; + } + std::vector index_range; + bool status = true; + if (splits[0] != ":") + status = SplitStringToIntegers(splits[0], ":", false, &index_range); + + if (index_range.size() == 0) { + index_range.push_back(0); + index_range.push_back(input.Dim() - 1); + } + + // Length tolerance of 3 -- 2 to account for edge effects when + // frame-length is 25ms and frame-shift is 10ms, and 1 for rounding effects + // since segments are usually retained up to 2 decimal places. + int32 length_tolerance = 3; + if (!(status && index_range.size() == 2 && + index_range[0] >= 0 && index_range[0] <= index_range[1] && + index_range[1] < input.Dim() + length_tolerance)) { + KALDI_ERR << "Invalid range specifier: " << range + << " for vector of size " << input.Dim(); + return false; + } + + if (index_range[1] >= input.Dim()) + KALDI_WARN << "Range " << index_range[0] << ":" << index_range[1] + << " goes beyond the vector dimension " << input.Dim(); + int32 size = std::min(index_range[1], input.Dim() - 1) - index_range[0] + 1; + output->Resize(size, kUndefined); + output->CopyFromVec(input.Range(index_range[0], size)); + return true; +} + +// template instantiation +template bool ExtractObjectRange(const Vector &, const std::string &, + Vector *); +template bool ExtractObjectRange(const Vector &, const std::string &, + Vector *); + +bool ExtractRangeSpecifier(const std::string &rxfilename_with_range, + std::string *data_rxfilename, + std::string *range) { + if (rxfilename_with_range.empty() || + rxfilename_with_range[rxfilename_with_range.size()-1] != ']') + KALDI_ERR << "ExtractRangeRspecifier called wrongly."; + std::vector splits; + SplitStringToVector(rxfilename_with_range, "[", false, &splits); + if (splits.size() == 2 && !splits[0].empty() && splits[1].size() > 1) { + *data_rxfilename = 
splits[0]; + range->assign(splits[1], 0, splits[1].size()-1); + return true; + } + return false; +} + +} // end namespace kaldi diff --git a/speechx/speechx/kaldi/util/kaldi-holder.h b/speechx/speechx/kaldi/util/kaldi-holder.h new file mode 100644 index 00000000..f495f27f --- /dev/null +++ b/speechx/speechx/kaldi/util/kaldi-holder.h @@ -0,0 +1,282 @@ +// util/kaldi-holder.h + +// Copyright 2009-2011 Microsoft Corporation +// 2016 Johns Hopkins University (author: Daniel Povey) +// 2016 Xiaohui Zhang + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_UTIL_KALDI_HOLDER_H_ +#define KALDI_UTIL_KALDI_HOLDER_H_ + +#include +#include "util/kaldi-io.h" +#include "util/text-utils.h" +#include "matrix/kaldi-vector.h" +#include "matrix/sparse-matrix.h" + +namespace kaldi { + + +// The Table class uses a Holder class to wrap objects, and make them behave +// in a "normalized" way w.r.t. reading and writing, so the Table class can +// be template-ized without too much trouble. Look below this +// comment (search for GenericHolder) to see what it looks like. +// +// Requirements of the holder class: +// +// They can only contain objects that can be read/written without external +// information; other objects cannot be stored in this type of archive. +// +// In terms of what functions it should have, see GenericHolder below. 
+// It is just for documentation.
+//
+// (1) Requirements of the Read and Write functions
+//
+// The Read and Write functions should have the property that in a longer
+// file, if the Read function is started from where the Write function started
+// writing, it should go to where the Write function stopped writing, in either
+// text or binary mode (but it's OK if it doesn't eat up trailing space).
+//
+// [Desirable property: when writing in text mode the output should contain
+// exactly one newline, at the end of the output; this makes it easier to
+// manipulate]
+//
+// [Desirable property for classes: the output should just be a binary-mode
+// header (if in binary mode and it's a Kaldi object, or no header
+// otherwise), and then the output of Object.Write().  This means that when
+// written to individual files with the scp: type of wspecifier, we can
+// read the individual files in the "normal" Kaldi way by reading the
+// binary header and then the object.]
+//
+//
+// The Write function takes a 'binary' argument. In general, each object will
+// have two formats: text and binary. However, it's permitted to throw() if
+// asked to read in the text format if there is none. The file will be open, if
+// the file system has binary/text modes, in the corresponding mode. However,
+// the object should have a file-mode in which it can read either text or binary
+// output. It announces this via the static IsReadInBinary() function. This
+// will generally be the binary mode and it means that where necessary, in text
+// formats, we must ignore \r characters.
+//
+// Memory requirements: if it allocates memory, the destructor should
+// free that memory. Copying and assignment of Holder objects may be
+// disallowed as the Table code never does this.


+/// GenericHolder serves to document the requirements of the Holder interface;
+/// it's not intended to be used.
+template class GenericHolder { + public: + typedef SomeType T; + + /// Must have a constructor that takes no arguments. + GenericHolder() { } + + /// Write() writes this object of type T. Possibly also writes a binary-mode + /// header so that the Read function knows which mode to read in (since the + /// Read function does not get this information). It's a static member so we + /// can write those not inside this class (can use this function with Value() + /// to write from this class). The Write method may throw if it cannot write + /// the object in the given (binary/non-binary) mode. The holder object can + /// assume the stream has been opened in the given mode (where relevant). The + /// object can write the data how it likes. + static bool Write(std::ostream &os, bool binary, const T &t); + + /// Reads into the holder. Must work out from the stream (which will be + /// opened on Windows in binary mode if the IsReadInBinary() function of this + /// class returns true, and text mode otherwise) whether the actual data is + /// binary or not (usually via reading the Kaldi binary-mode header). + /// We put the responsibility for reading the Kaldi binary-mode header in the + /// Read function (rather than making the binary mode an argument to this + /// function), so that for non-Kaldi binary files we don't have to write the + /// header, which would prevent the file being read by non-Kaldi programs + /// (e.g. if we write to individual files using an scp). + /// Read must deallocate any existing data we have here, if applicable (must + /// not assume the object was newly constructed). + /// Returns true on success. + /// If Read() returns false, the contents of this object and hence the value + /// returned by Value() may be undefined. + bool Read(std::istream &is); + + /// IsReadInBinary() will return true if the object wants the file to be + /// opened in binary for reading (if the file system has binary/text modes), + /// and false otherwise. Static function. 
Kaldi objects always return true + /// as they always read in binary mode. Note that we must be able to read, in + /// this mode, objects written in both text and binary mode by Write (which + /// may mean ignoring "\r" characters). I doubt we will ever want this + /// function to return false. + static bool IsReadInBinary() { return true; } + + /// Returns the value of the object held here. Will only + /// ever be called if Read() has been previously called and it returned + /// true (so OK to throw exception if no object was read). + T &Value() { return t_; } // if t is a pointer, would return *t_; + + /// The Clear() function doesn't have to do anything. Its purpose is to + /// allow the object to free resources if they're no longer needed. + void Clear() { } + + /// This swaps the objects held by *this and *other (preferably a shallow + /// swap). Note, this is just an example. The swap is with the *same type* + /// of holder, not with some nonexistent base-class (remember, GenericHolder is + /// an example for documentation, not a base-class). + void Swap(GenericHolder *other) { std::swap(t_, other->t_); } + + /// At the time of writing this will only do something meaningful + /// KaldiObjectHolder holding matrix objects, in order to extract a holder + /// holding a sub-matrix specified by 'range', e.g. [0:3,2:10], like in Matlab + /// but with zero-based indexing. It returns true with successful extraction + /// of the range, false if the range was invalid or outside the bounds of the + /// matrix. For other types of holder it just throws an error. + bool ExtractRange(const GenericHolder &other, const std::string &range) { + KALDI_ERR << "ExtractRange is not defined for this type of holder."; + return false; + } + + /// If the object held pointers, the destructor would free them. + ~GenericHolder() { } + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(GenericHolder); + T t_; // t_ may alternatively be of type T*. 
+}; + + +// See kaldi-holder-inl.h for examples of some actual Holder +// classes and templates. + + +// The following two typedefs should probably be in their own file, but they're +// here until there are enough of them to warrant their own header. + + +/// \addtogroup holders +/// @{ + +/// KaldiObjectHolder works for Kaldi objects that have the "standard" Read +/// and Write functions, and a copy constructor. +template class KaldiObjectHolder; + +/// BasicHolder is valid for float, double, bool, and integer +/// types. There will be a compile time error otherwise, because +/// we make sure that the {Write, Read}BasicType functions do not +/// get instantiated for other types. +template class BasicHolder; + + +// A Holder for a vector of basic types, e.g. +// std::vector, std::vector, and so on. +// Note: a basic type is defined as a type for which ReadBasicType +// and WriteBasicType are implemented, i.e. integer and floating +// types, and bool. +template class BasicVectorHolder; + + +// A holder for vectors of vectors of basic types, e.g. +// std::vector >, and so on. +// Note: a basic type is defined as a type for which ReadBasicType +// and WriteBasicType are implemented, i.e. integer and floating +// types, and bool. +template class BasicVectorVectorHolder; + +// A holder for vectors of pairs of basic types, e.g. +// std::vector >, and so on. +// Note: a basic type is defined as a type for which ReadBasicType +// and WriteBasicType are implemented, i.e. integer and floating +// types, and bool. Text format is (e.g. for integers), +// "1 12 ; 43 61 ; 17 8 \n" +template class BasicPairVectorHolder; + +/// We define a Token (not a typedef, just a word) as a nonempty, printable, +/// whitespace-free std::string. The binary and text formats here are the same +/// (newline-terminated) and as such we don't bother with the binary-mode +/// headers. +class TokenHolder; + +/// Class TokenVectorHolder is a Holder class for vectors of Tokens +/// (T == std::string). 
+class TokenVectorHolder; + +/// A class for reading/writing HTK-format matrices. +/// T == std::pair, HtkHeader> +class HtkMatrixHolder; + +/// A class for reading/writing Sphinx format matrices. +template class SphinxMatrixHolder; + +/// This templated function exists so that we can write .scp files with +/// 'object ranges' specified: the canonical example is a [first:last] range +/// of rows of a matrix, or [first-row:last-row,first-column,last-column] +/// of a matrix. We can also support [begin-time:end-time] of a wave +/// file. The string 'range' is whatever is in the square brackets; it is +/// parsed inside this function. +/// This function returns true if the partial object was successfully extracted, +/// and false if there was an error such as an invalid range. +/// The generic version of this function just fails; we overload the template +/// whenever we need it for a specific class. +template +bool ExtractObjectRange(const T &input, const std::string &range, T *output) { + KALDI_ERR << "Ranges not supported for objects of this type."; + return false; +} + +/// The template is specialized with a version that actually does something, +/// for types Matrix and Matrix. We can later add versions of +/// this template for other types, such as Vector, which can meaningfully +/// have ranges extracted. +template +bool ExtractObjectRange(const Matrix &input, const std::string &range, + Matrix *output); + +/// The template is specialized types Vector and Vector. 
+template +bool ExtractObjectRange(const Vector &input, const std::string &range, + Vector *output); + +/// GeneralMatrix is always of type BaseFloat +bool ExtractObjectRange(const GeneralMatrix &input, const std::string &range, + GeneralMatrix *output); + +/// CompressedMatrix is always of the type BaseFloat but it is more +/// efficient to provide template as it uses CompressedMatrix's own +/// conversion to Matrix +template +bool ExtractObjectRange(const CompressedMatrix &input, const std::string &range, + Matrix *output); + +// In SequentialTableReaderScriptImpl and RandomAccessTableReaderScriptImpl, for +// cases where the scp contained 'range specifiers' (things in square brackets +// identifying parts of objects like matrices), use this function to separate +// the input string 'rxfilename_with_range' (e.g "1.ark:100[1:2,2:10]") into the data_rxfilename +// (e.g. "1.ark:100") and the optional range specifier which will be everything +// inside the square brackets. It returns true if everything seems OK, and +// false if for example the string contained more than one '['. This function +// should only be called if 'line' ends in ']', otherwise it is an error. +bool ExtractRangeSpecifier(const std::string &rxfilename_with_range, + std::string *data_rxfilename, + std::string *range); + + +/// @} end "addtogroup holders" + + +} // end namespace kaldi + +#include "util/kaldi-holder-inl.h" + +#endif // KALDI_UTIL_KALDI_HOLDER_H_ diff --git a/speechx/speechx/kaldi/util/kaldi-io-inl.h b/speechx/speechx/kaldi/util/kaldi-io-inl.h new file mode 100644 index 00000000..2474f701 --- /dev/null +++ b/speechx/speechx/kaldi/util/kaldi-io-inl.h @@ -0,0 +1,46 @@ +// util/kaldi-io-inl.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_UTIL_KALDI_IO_INL_H_ +#define KALDI_UTIL_KALDI_IO_INL_H_ + +#include + +namespace kaldi { + +bool Input::Open(const std::string &rxfilename, bool *binary) { + return OpenInternal(rxfilename, true, binary); +} + +bool Input::OpenTextMode(const std::string &rxfilename) { + return OpenInternal(rxfilename, false, NULL); +} + +bool Input::IsOpen() { + return impl_ != NULL; +} + +bool Output::IsOpen() { + return impl_ != NULL; +} + + +} // end namespace kaldi. + + +#endif // KALDI_UTIL_KALDI_IO_INL_H_ diff --git a/speechx/speechx/kaldi/util/kaldi-io.cc b/speechx/speechx/kaldi/util/kaldi-io.cc new file mode 100644 index 00000000..96cd8fa1 --- /dev/null +++ b/speechx/speechx/kaldi/util/kaldi-io.cc @@ -0,0 +1,884 @@ +// util/kaldi-io.cc + +// Copyright 2009-2011 Microsoft Corporation; Jan Silovsky +// 2016 Xiaohui Zhang + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. 
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#include "util/kaldi-io.h" +#include +#include +#include "base/kaldi-math.h" +#include "util/text-utils.h" +#include "util/parse-options.h" +#include "util/kaldi-holder.h" +#include "util/kaldi-pipebuf.h" +#include "util/kaldi-table.h" // for Classify{W,R}specifier +#include +#include + +#ifdef KALDI_CYGWIN_COMPAT +#include "util/kaldi-cygwin-io-inl.h" +#define MapOsPath(x) MapCygwinPath(x) +#else // KALDI_CYGWIN_COMPAT +#define MapOsPath(x) x +#endif // KALDI_CYGWIN_COMPAT + + +#if defined(_MSC_VER) +static FILE *popen(const char* command, const char* mode) { +#ifdef KALDI_CYGWIN_COMPAT + return kaldi::CygwinCompatPopen(command, mode); +#else // KALDI_CYGWIN_COMPAT + return _popen(command, mode); +#endif // KALDI_CYGWIN_COMPAT +} +#endif // _MSC_VER + +namespace kaldi { + +#ifndef _MSC_VER // on VS, we don't need this type. +// could replace basic_pipebuf with stdio_filebuf on some platforms. +// Would mean we could use less of our own code. +typedef basic_pipebuf PipebufType; +#endif +} + +namespace kaldi { + +std::string PrintableRxfilename(const std::string &rxfilename) { + if (rxfilename == "" || rxfilename == "-") { + return "standard input"; + } else { + // If this call to Escape later causes compilation issues, + // just replace it with "return rxfilename"; it's only a + // pretty-printing issue. + return ParseOptions::Escape(rxfilename); + } +} + + +std::string PrintableWxfilename(const std::string &wxfilename) { + if (wxfilename == "" || wxfilename == "-") { + return "standard output"; + } else { + // If this call to Escape later causes compilation issues, + // just replace it with "return wxfilename"; it's only a + // pretty-printing issue. 
+ return ParseOptions::Escape(wxfilename); + } +} + + +OutputType ClassifyWxfilename(const std::string &filename) { + const char *c = filename.c_str(); + size_t length = filename.length(); + char first_char = c[0], + last_char = (length == 0 ? '\0' : c[filename.length()-1]); + + // if 'filename' is "" or "-", return kStandardOutput. + if (length == 0 || (length == 1 && first_char == '-')) + return kStandardOutput; + else if (first_char == '|') return kPipeOutput; // An output pipe like "|blah". + else if (isspace(first_char) || isspace(last_char) || last_char == '|') { + return kNoOutput; // Leading or trailing space: can't interpret this. + // Final '|' would represent an input pipe, not an + // output pipe. + } else if ((first_char == 'a' || first_char == 's') && + strchr(c, ':') != NULL && + (ClassifyWspecifier(filename, NULL, NULL, NULL) != kNoWspecifier || + ClassifyRspecifier(filename, NULL, NULL) != kNoRspecifier)) { + // e.g. ark:something or scp:something... this is almost certainly a + // scripting error, so call it an error rather than treating it as a file. + // In practice in modern kaldi scripts all (r,w)filenames begin with "ark" + // or "scp", even though technically speaking options like "b", "t", "s" or + // "cs" can appear before the ark or scp, like "b,ark". For efficiency, + // and because this code is really just a nicety to catch errors earlier + // than they would otherwise be caught, we only call those extra functions + // for filenames beginning with 'a' or 's'. + return kNoOutput; + } else if (isdigit(last_char)) { + // This could be a file, but we have to see if it's an offset into a file + // (like foo.ark:4314328), which is not allowed for writing (but is + // allowed for reaching). This eliminates some things which would be + // valid UNIX filenames but are not allowed by Kaldi. (Even if we allowed + // such filenames for writing, we woudln't be able to correctly read them). 
+ const char *d = c + length - 1; + while (isdigit(*d) && d > c) d--; + if (*d == ':') return kNoOutput; + // else it could still be a filename; continue to the next check. + } + + // At this point it matched no other pattern so we assume a filename, but we + // check for internal '|' as it's a common source of errors to have pipe + // commands without the pipe in the right place. Say that it can't be + // classified. + if (strchr(c, '|') != NULL) { + KALDI_WARN << "Trying to classify wxfilename with pipe symbol in the" + " wrong place (pipe without | at the beginning?): " << + filename; + return kNoOutput; + } + return kFileOutput; // It matched no other pattern: assume it's a filename. +} + + +InputType ClassifyRxfilename(const std::string &filename) { + const char *c = filename.c_str(); + size_t length = filename.length(); + char first_char = c[0], + last_char = (length == 0 ? '\0' : c[filename.length()-1]); + + // if 'filename' is "" or "-", return kStandardInput. + if (length == 0 || (length == 1 && first_char == '-')) { + return kStandardInput; + } else if (first_char == '|') { + return kNoInput; // An output pipe like "|blah": not + // valid for input. + } else if (last_char == '|') { + return kPipeInput; + } else if (isspace(first_char) || isspace(last_char)) { + return kNoInput; // We don't allow leading or trailing space in a filename. + } else if ((first_char == 'a' || first_char == 's') && + strchr(c, ':') != NULL && + (ClassifyWspecifier(filename, NULL, NULL, NULL) != kNoWspecifier || + ClassifyRspecifier(filename, NULL, NULL) != kNoRspecifier)) { + // e.g. ark:something or scp:something... this is almost certainly a + // scripting error, so call it an error rather than treating it as a file. + // In practice in modern kaldi scripts all (r,w)filenames begin with "ark" + // or "scp", even though technically speaking options like "b", "t", "s" or + // "cs" can appear before the ark or scp, like "b,ark". 
For efficiency, + // and because this code is really just a nicety to catch errors earlier + // than they would otherwise be caught, we only call those extra functions + // for filenames beginning with 'a' or 's'. + return kNoInput; + } else if (isdigit(last_char)) { + const char *d = c + length - 1; + while (isdigit(*d) && d > c) d--; + if (*d == ':') return kOffsetFileInput; // Filename is like + // some_file:12345 + // otherwise it could still be a filename; continue to the next check. + } + + + // At this point it matched no other pattern so we assume a filename, but + // we check for '|' as it's a common source of errors to have pipe + // commands without the pipe in the right place. Say that it can't be + // classified in this case. + if (strchr(c, '|') != NULL) { + KALDI_WARN << "Trying to classify rxfilename with pipe symbol in the" + " wrong place (pipe without | at the end?): " << filename; + return kNoInput; + } + return kFileInput; // It matched no other pattern: assume it's a filename. +} + +class OutputImplBase { + public: + // Open will open it as a file (no header), and return true + // on success. It cannot be called on an already open stream. + virtual bool Open(const std::string &filename, bool binary) = 0; + virtual std::ostream &Stream() = 0; + virtual bool Close() = 0; + virtual ~OutputImplBase() { } +}; + + +class FileOutputImpl: public OutputImplBase { + public: + virtual bool Open(const std::string &filename, bool binary) { + if (os_.is_open()) KALDI_ERR << "FileOutputImpl::Open(), " + << "open called on already open file."; + filename_ = filename; + os_.open(MapOsPath(filename_).c_str(), + binary ? std::ios_base::out | std::ios_base::binary + : std::ios_base::out); + return os_.is_open(); + } + + virtual std::ostream &Stream() { + if (!os_.is_open()) + KALDI_ERR << "FileOutputImpl::Stream(), file is not open."; + // I believe this error can only arise from coding error. 
+ return os_; + } + + virtual bool Close() { + if (!os_.is_open()) + KALDI_ERR << "FileOutputImpl::Close(), file is not open."; + // I believe this error can only arise from coding error. + os_.close(); + return !(os_.fail()); + } + virtual ~FileOutputImpl() { + if (os_.is_open()) { + os_.close(); + if (os_.fail()) + KALDI_ERR << "Error closing output file " << filename_; + } + } + private: + std::string filename_; + std::ofstream os_; +}; + +class StandardOutputImpl: public OutputImplBase { + public: + StandardOutputImpl(): is_open_(false) { } + + virtual bool Open(const std::string &filename, bool binary) { + if (is_open_) KALDI_ERR << "StandardOutputImpl::Open(), " + "open called on already open file."; +#ifdef _MSC_VER + _setmode(_fileno(stdout), binary ? _O_BINARY : _O_TEXT); +#endif + is_open_ = std::cout.good(); + return is_open_; + } + + virtual std::ostream &Stream() { + if (!is_open_) + KALDI_ERR << "StandardOutputImpl::Stream(), object not initialized."; + // I believe this error can only arise from coding error. + return std::cout; + } + + virtual bool Close() { + if (!is_open_) + KALDI_ERR << "StandardOutputImpl::Close(), file is not open."; + is_open_ = false; + std::cout << std::flush; + return !(std::cout.fail()); + } + virtual ~StandardOutputImpl() { + if (is_open_) { + std::cout << std::flush; + if (std::cout.fail()) + KALDI_ERR << "Error writing to standard output"; + } + } + private: + bool is_open_; +}; + +class PipeOutputImpl: public OutputImplBase { + public: + PipeOutputImpl(): f_(NULL), os_(NULL) { } + + virtual bool Open(const std::string &wxfilename, bool binary) { + filename_ = wxfilename; + KALDI_ASSERT(f_ == NULL); // Make sure closed. + KALDI_ASSERT(wxfilename.length() != 0 && wxfilename[0] == '|'); // should + // start with '|' + std::string cmd_name(wxfilename, 1); +#if defined(_MSC_VER) || defined(__CYGWIN__) + f_ = popen(cmd_name.c_str(), (binary ? 
"wb" : "w")); +#else + f_ = popen(cmd_name.c_str(), "w"); +#endif + if (!f_) { // Failure. + KALDI_WARN << "Failed opening pipe for writing, command is: " + << cmd_name << ", errno is " << strerror(errno); + return false; + } else { +#ifndef _MSC_VER + fb_ = new PipebufType(f_, // Using this constructor won't make the + // destructor try to close the stream when + // we're done. + (binary ? std::ios_base::out| + std::ios_base::binary + :std::ios_base::out)); + KALDI_ASSERT(fb_ != NULL); // or would be alloc error. + os_ = new std::ostream(fb_); +#else + os_ = new std::ofstream(f_); +#endif + return os_->good(); + } + } + + virtual std::ostream &Stream() { + if (os_ == NULL) KALDI_ERR << "PipeOutputImpl::Stream()," + " object not initialized."; + // I believe this error can only arise from coding error. + return *os_; + } + + virtual bool Close() { + if (os_ == NULL) KALDI_ERR << "PipeOutputImpl::Close(), file is not open."; + bool ok = true; + os_->flush(); + if (os_->fail()) ok = false; + delete os_; + os_ = NULL; + int status; +#ifdef _MSC_VER + status = _pclose(f_); +#else + status = pclose(f_); +#endif + if (status) + KALDI_WARN << "Pipe " << filename_ << " had nonzero return status " + << status; + f_ = NULL; +#ifndef _MSC_VER + delete fb_; + fb_ = NULL; +#endif + return ok; + } + virtual ~PipeOutputImpl() { + if (os_) { + if (!Close()) + KALDI_ERR << "Error writing to pipe " << PrintableWxfilename(filename_); + } + } + private: + std::string filename_; + FILE *f_; +#ifndef _MSC_VER + PipebufType *fb_; +#endif + std::ostream *os_; +}; + + + +class InputImplBase { + public: + // Open will open it as a file, and return true on success. + // May be called twice only for kOffsetFileInput (otherwise, + // if called twice, we just create a new Input object, to avoid + // having to deal with the extra hassle of reopening with the + // same object. 
+ // Note that we will to call Open with true (binary) for + // for text-mode Kaldi files; the only actual text-mode input + // is for non-Kaldi files. + virtual bool Open(const std::string &filename, bool binary) = 0; + virtual std::istream &Stream() = 0; + virtual int32 Close() = 0; // We only need to check failure in the case of + // kPipeInput. + // on close for input streams. + virtual InputType MyType() = 0; // Because if it's kOffsetFileInput, we may + // call Open twice + // (has efficiency benefits). + + virtual ~InputImplBase() { } +}; + +class FileInputImpl: public InputImplBase { + public: + virtual bool Open(const std::string &filename, bool binary) { + if (is_.is_open()) KALDI_ERR << "FileInputImpl::Open(), " + << "open called on already open file."; + is_.open(MapOsPath(filename).c_str(), + binary ? std::ios_base::in | std::ios_base::binary + : std::ios_base::in); + return is_.is_open(); + } + + virtual std::istream &Stream() { + if (!is_.is_open()) + KALDI_ERR << "FileInputImpl::Stream(), file is not open."; + // I believe this error can only arise from coding error. + return is_; + } + + virtual int32 Close() { + if (!is_.is_open()) + KALDI_ERR << "FileInputImpl::Close(), file is not open."; + // I believe this error can only arise from coding error. + is_.close(); + // Don't check status. + return 0; + } + + virtual InputType MyType() { return kFileInput; } + + virtual ~FileInputImpl() { + // Stream will automatically be closed, and we don't care about + // whether it fails. + } + private: + std::ifstream is_; +}; + + +class StandardInputImpl: public InputImplBase { + public: + StandardInputImpl(): is_open_(false) { } + + virtual bool Open(const std::string &filename, bool binary) { + if (is_open_) KALDI_ERR << "StandardInputImpl::Open(), " + "open called on already open file."; + is_open_ = true; +#ifdef _MSC_VER + _setmode(_fileno(stdin), binary ? 
_O_BINARY : _O_TEXT); +#endif + return true; // Don't check good() because would be false if + // eof, which may be valid input. + } + + virtual std::istream &Stream() { + if (!is_open_) + KALDI_ERR << "StandardInputImpl::Stream(), object not initialized."; + // I believe this error can only arise from coding error. + return std::cin; + } + + virtual InputType MyType() { return kStandardInput; } + + virtual int32 Close() { + if (!is_open_) KALDI_ERR << "StandardInputImpl::Close(), file is not open."; + is_open_ = false; + return 0; + } + virtual ~StandardInputImpl() { } + private: + bool is_open_; +}; + +class PipeInputImpl: public InputImplBase { + public: + PipeInputImpl(): f_(NULL), is_(NULL) { } + + virtual bool Open(const std::string &rxfilename, bool binary) { + filename_ = rxfilename; + KALDI_ASSERT(f_ == NULL); // Make sure closed. + KALDI_ASSERT(rxfilename.length() != 0 && + rxfilename[rxfilename.length()-1] == '|'); // should end with '|' + std::string cmd_name(rxfilename, 0, rxfilename.length()-1); +#if defined(_MSC_VER) || defined(__CYGWIN__) + f_ = popen(cmd_name.c_str(), (binary ? "rb" : "r")); +#else + f_ = popen(cmd_name.c_str(), "r"); +#endif + + if (!f_) { // Failure. + KALDI_WARN << "Failed opening pipe for reading, command is: " + << cmd_name << ", errno is " << strerror(errno); + return false; + } else { +#ifndef _MSC_VER + fb_ = new PipebufType(f_, // Using this constructor won't lead the + // destructor to close the stream. + (binary ? std::ios_base::in| + std::ios_base::binary + :std::ios_base::in)); + KALDI_ASSERT(fb_ != NULL); // or would be alloc error. + is_ = new std::istream(fb_); +#else + is_ = new std::ifstream(f_); +#endif + if (is_->fail() || is_->bad()) return false; + if (is_->eof()) { + KALDI_WARN << "Pipe opened with command " + << PrintableRxfilename(rxfilename) + << " is empty."; + // don't return false: empty may be valid. 
+ } + return true; + } + } + + virtual std::istream &Stream() { + if (is_ == NULL) + KALDI_ERR << "PipeInputImpl::Stream(), object not initialized."; + // I believe this error can only arise from coding error. + return *is_; + } + + virtual int32 Close() { + if (is_ == NULL) + KALDI_ERR << "PipeInputImpl::Close(), file is not open."; + delete is_; + is_ = NULL; + int32 status; +#ifdef _MSC_VER + status = _pclose(f_); +#else + status = pclose(f_); +#endif + if (status) + KALDI_WARN << "Pipe " << filename_ << " had nonzero return status " + << status; + f_ = NULL; +#ifndef _MSC_VER + delete fb_; + fb_ = NULL; +#endif + return status; + } + virtual ~PipeInputImpl() { + if (is_) + Close(); + } + virtual InputType MyType() { return kPipeInput; } + private: + std::string filename_; + FILE *f_; +#ifndef _MSC_VER + PipebufType *fb_; +#endif + std::istream *is_; +}; + +/* +#else + +// Just have an empty implementation of the pipe input that crashes if +// called. +class PipeInputImpl: public InputImplBase { + public: + PipeInputImpl() { KALDI_ASSERT(0 && "Pipe input not yet supported on this + platform."); } + virtual bool Open(const std::string, bool) { return 0; } + virtual std::istream &Stream() const { return NULL; } + virtual void Close() {} + virtual InputType MyType() { return kPipeInput; } +}; + +#endif +*/ + +class OffsetFileInputImpl: public InputImplBase { + // This class is a bit more complicated than the + + public: + // splits a filename like /my/file:123 into /my/file and the + // number 123. Crashes if not this format. + static void SplitFilename(const std::string &rxfilename, + std::string *filename, + size_t *offset) { + size_t pos = rxfilename.find_last_of(':'); + KALDI_ASSERT(pos != std::string::npos); // would indicate error in calling + // code, as the filename is supposed to be of the correct form at this + // point. 
+ *filename = std::string(rxfilename, 0, pos); + std::string number(rxfilename, pos+1); + bool ans = ConvertStringToInteger(number, offset); + if (!ans) + KALDI_ERR << "Cannot get offset from filename " << rxfilename + << " (possibly you compiled in 32-bit and have a >32-bit" + << " byte offset into a file; you'll have to compile 64-bit."; + } + + bool Seek(size_t offset) { + size_t cur_pos = is_.tellg(); + if (cur_pos == offset) return true; + else if (cur_pos offset) { + // We're close enough that it may be faster to just + // read that data, rather than seek. + for (size_t i = cur_pos; i < offset; i++) + is_.get(); + return (is_.tellg() == std::streampos(offset)); + } + // Try to actually seek. + is_.seekg(offset, std::ios_base::beg); + if (is_.fail()) { // failbit or badbit is set [error happened] + is_.close(); + return false; // failure. + } else { + is_.clear(); // Clear any failure bits (e.g. eof). + return true; // success. + } + } + + // This Open routine is unusual in that it is designed to work even + // if it was already open. This for efficiency when seeking multiple + // times. + virtual bool Open(const std::string &rxfilename, bool binary) { + if (is_.is_open()) { + // We are opening when we have an already-open file. + // We may have to seek within this file, or else close it and + // open a different one. + std::string tmp_filename; + size_t offset; + SplitFilename(rxfilename, &tmp_filename, &offset); + if (tmp_filename == filename_ && binary == binary_) { // Just seek + is_.clear(); // clear fail bit, etc. + return Seek(offset); + } else { + is_.close(); // don't bother checking error status of is_. + filename_ = tmp_filename; + is_.open(MapOsPath(filename_).c_str(), + binary ? 
std::ios_base::in | std::ios_base::binary + : std::ios_base::in); + if (!is_.is_open()) return false; + else + return Seek(offset); + } + } else { + size_t offset; + SplitFilename(rxfilename, &filename_, &offset); + binary_ = binary; + is_.open(MapOsPath(filename_).c_str(), + binary ? std::ios_base::in | std::ios_base::binary + : std::ios_base::in); + if (!is_.is_open()) return false; + else + return Seek(offset); + } + } + + virtual std::istream &Stream() { + if (!is_.is_open()) + KALDI_ERR << "FileInputImpl::Stream(), file is not open."; + // I believe this error can only arise from coding error. + return is_; + } + + virtual int32 Close() { + if (!is_.is_open()) + KALDI_ERR << "FileInputImpl::Close(), file is not open."; + // I believe this error can only arise from coding error. + is_.close(); + // Don't check status. + return 0; + } + + virtual InputType MyType() { return kOffsetFileInput; } + + virtual ~OffsetFileInputImpl() { + // Stream will automatically be closed, and we don't care about + // whether it fails. + } + private: + std::string filename_; // the actual filename + bool binary_; // true if was opened in binary mode. + std::ifstream is_; +}; + + +Output::Output(const std::string &wxfilename, bool binary, + bool write_header):impl_(NULL) { + if (!Open(wxfilename, binary, write_header)) { + if (impl_) { + delete impl_; + impl_ = NULL; + } + KALDI_ERR << "Error opening output stream " << + PrintableWxfilename(wxfilename); + } +} + +bool Output::Close() { + if (!impl_) { + return false; // error to call Close if not open. + } else { + bool ans = impl_->Close(); + delete impl_; + impl_ = NULL; + return ans; + } +} + +Output::~Output() { + if (impl_) { + bool ok = impl_->Close(); + delete impl_; + impl_ = NULL; + if (!ok) + KALDI_ERR << "Error closing output file " + << PrintableWxfilename(filename_) + << (ClassifyWxfilename(filename_) == kFileOutput ? 
+ " (disk full?)" : ""); + } +} + +std::ostream &Output::Stream() { // will throw if not open; else returns + // stream. + if (!impl_) KALDI_ERR << "Output::Stream() called but not open."; + return impl_->Stream(); +} + +bool Output::Open(const std::string &wxfn, bool binary, bool header) { + if (IsOpen()) { + if (!Close()) { // Throw here rather than return status, as it's an error + // about something else: if the user wanted to avoid the exception he/she + // could have called Close(). + KALDI_ERR << "Output::Open(), failed to close output stream: " + << PrintableWxfilename(filename_); + } + } + + filename_ = wxfn; + + OutputType type = ClassifyWxfilename(wxfn); + KALDI_ASSERT(impl_ == NULL); + + if (type == kFileOutput) { + impl_ = new FileOutputImpl(); + } else if (type == kStandardOutput) { + impl_ = new StandardOutputImpl(); + } else if (type == kPipeOutput) { + impl_ = new PipeOutputImpl(); + } else { // type == kNoOutput + KALDI_WARN << "Invalid output filename format "<< + PrintableWxfilename(wxfn); + return false; + } + if (!impl_->Open(wxfn, binary)) { + delete impl_; + impl_ = NULL; + return false; // failed to open. + } else { // successfully opened it. + if (header) { + InitKaldiOutputStream(impl_->Stream(), binary); + bool ok = impl_->Stream().good(); // still OK? + if (!ok) { + delete impl_; + impl_ = NULL; + return false; + } + return true; + } else { + return true; + } + } +} + + +Input::Input(const std::string &rxfilename, bool *binary): impl_(NULL) { + if (!Open(rxfilename, binary)) { + KALDI_ERR << "Error opening input stream " + << PrintableRxfilename(rxfilename); + } +} + +int32 Input::Close() { + if (impl_) { + int32 ans = impl_->Close(); + delete impl_; + impl_ = NULL; + return ans; + } else { + return 0; + } +} + +bool Input::OpenInternal(const std::string &rxfilename, + bool file_binary, + bool *contents_binary) { + InputType type = ClassifyRxfilename(rxfilename); + if (IsOpen()) { + // May have to close the stream first. 
+    if (type == kOffsetFileInput && impl_->MyType() == kOffsetFileInput) {
+      // We want to use the same object to Open...  this is in case
+      // the files are the same, so we can just seek.
+      if (!impl_->Open(rxfilename, file_binary)) {  // true is binary mode--
+        // always open in binary.
+        delete impl_;
+        impl_ = NULL;
+        return false;
+      }
+      // read the binary header, if requested.
+      if (contents_binary != NULL)
+        return InitKaldiInputStream(impl_->Stream(), contents_binary);
+      else
+        return true;
+    } else {
+      Close();
+      // and fall through to code below which actually opens the file.
+    }
+  }
+  if (type == kFileInput) {
+    impl_ = new FileInputImpl();
+  } else if (type == kStandardInput) {
+    impl_ = new StandardInputImpl();
+  } else if (type == kPipeInput) {
+    impl_ = new PipeInputImpl();
+  } else if (type == kOffsetFileInput) {
+    impl_ = new OffsetFileInputImpl();
+  } else {  // type == kNoInput
+    KALDI_WARN << "Invalid input filename format "<<
+        PrintableRxfilename(rxfilename);
+    return false;
+  }
+  if (!impl_->Open(rxfilename, file_binary)) {  // true is binary mode--
+    // always read in binary.
+    delete impl_;
+    impl_ = NULL;
+    return false;
+  }
+  if (contents_binary != NULL)
+    return InitKaldiInputStream(impl_->Stream(), contents_binary);
+  else
+    return true;
+}
+
+
+Input::~Input() { if (impl_) Close(); }
+
+
+std::istream &Input::Stream() {
+  if (!IsOpen()) KALDI_ERR << "Input::Stream(), not open.";
+  return impl_->Stream();
+}
+
+
+template <> void ReadKaldiObject(const std::string &filename,
+                                 Matrix<float> *m) {
+  if (!filename.empty() && filename[filename.size() - 1] == ']') {
+    // This filename seems to have a 'range'... like foo.ark:4312423[20:30].
+    // (the bit in square brackets is the range).
+    std::string rxfilename, range;
+    if (!ExtractRangeSpecifier(filename, &rxfilename, &range)) {
+      KALDI_ERR << "Could not make sense of possible range specifier in filename "
+                << "while reading matrix: " << filename;
+    }
+    Matrix<float> temp;
+    bool binary_in;
+    Input ki(rxfilename, &binary_in);
+    temp.Read(ki.Stream(), binary_in);
+    if (!ExtractObjectRange(temp, range, m)) {
+      KALDI_ERR << "Error extracting range of object: " << filename;
+    }
+  } else {
+    // The normal case, there is no range.
+    bool binary_in;
+    Input ki(filename, &binary_in);
+    m->Read(ki.Stream(), binary_in);
+  }
+}
+
+template <> void ReadKaldiObject(const std::string &filename,
+                                 Matrix<double> *m) {
+  if (!filename.empty() && filename[filename.size() - 1] == ']') {
+    // This filename seems to have a 'range'... like foo.ark:4312423[20:30].
+    // (the bit in square brackets is the range).
+    std::string rxfilename, range;
+    if (!ExtractRangeSpecifier(filename, &rxfilename, &range)) {
+      KALDI_ERR << "Could not make sense of possible range specifier in filename "
+                << "while reading matrix: " << filename;
+    }
+    Matrix<double> temp;
+    bool binary_in;
+    Input ki(rxfilename, &binary_in);
+    temp.Read(ki.Stream(), binary_in);
+    if (!ExtractObjectRange(temp, range, m)) {
+      KALDI_ERR << "Error extracting range of object: " << filename;
+    }
+  } else {
+    // The normal case, there is no range.
+    bool binary_in;
+    Input ki(filename, &binary_in);
+    m->Read(ki.Stream(), binary_in);
+  }
+}
+
+
+
+}  // end namespace kaldi
diff --git a/speechx/speechx/kaldi/util/kaldi-io.h b/speechx/speechx/kaldi/util/kaldi-io.h
new file mode 100644
index 00000000..c28be8a6
--- /dev/null
+++ b/speechx/speechx/kaldi/util/kaldi-io.h
@@ -0,0 +1,280 @@
+// util/kaldi-io.h
+
+// Copyright 2009-2011  Microsoft Corporation;  Jan Silovsky
+//                2016  Xiaohui Zhang
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+//  http://www.apache.org/licenses/LICENSE-2.0
+
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+#ifndef KALDI_UTIL_KALDI_IO_H_
+#define KALDI_UTIL_KALDI_IO_H_
+
+#ifdef _MSC_VER
+# include <fcntl.h>
+# include <io.h>
+#endif
+#include <cctype>  // For isspace.
+#include <map>
+#include <string>
+#include "base/kaldi-common.h"
+#include "matrix/kaldi-matrix.h"
+
+
+namespace kaldi {
+
+class OutputImplBase;  // Forward decl; defined in a .cc file
+class InputImplBase;   // Forward decl; defined in a .cc file
+
+/// \addtogroup io_group
+/// @{
+
+// The Output and Input classes handle stream-opening for "extended" filenames
+// that include actual files, standard-input/standard-output, pipes, and
+// offsets into actual files.  They also handle reading and writing the
+// binary-mode headers for Kaldi files, where applicable. 
The classes have +// versions of the Open routines that throw and do not throw, depending whether +// the calling code wants to catch the errors or not; there are also versions +// that write (or do not write) the Kaldi binary-mode header that says if it's +// binary mode. Generally files that contain Kaldi objects will have the header +// on, so we know upon reading them whether they have the header. So you would +// use the OpenWithHeader routines for these (or the constructor); but other +// types of objects (e.g. FSTs) would have files without a header so you would +// use OpenNoHeader. + +// We now document the types of extended filenames that we use. +// +// A "wxfilename" is an extended filename for writing. It can take three forms: +// (1) Filename: e.g. "/some/filename", "./a/b/c", "c:\Users\dpovey\My +// Documents\\boo" +// (whatever the actual file-system interprets) +// (2) Standard output: "" or "-" +// (3) A pipe: e.g. "| gzip -c > /tmp/abc.gz" +// +// +// A "rxfilename" is an extended filename for reading. It can take four forms: +// (1) An actual filename, whatever the file-system can read, e.g. "/my/file". +// (2) Standard input: "" or "-" +// (3) A pipe: e.g. "gunzip -c /tmp/abc.gz |" +// (4) An offset into a file, e.g.: "/mnt/blah/data/1.ark:24871" +// [these are created by the Table and TableWriter classes; I may also write +// a program that creates them for arbitrary files] +// + + +// Typical usage: +// ... +// bool binary; +// MyObject.Write(Output(some_filename, binary).Stream(), binary); +// +// ... 
more extensive example: +// { +// Output ko(some_filename, binary); +// MyObject1.Write(ko.Stream(), binary); +// MyObject2.Write(ko.Stream(), binary); +// } + + + +enum OutputType { + kNoOutput, + kFileOutput, + kStandardOutput, + kPipeOutput +}; + +/// ClassifyWxfilename interprets filenames as follows: +/// - kNoOutput: invalid filenames (leading or trailing space, things that look +/// like wspecifiers and rspecifiers or like pipes to read from with leading +/// |. +/// - kFileOutput: Normal filenames +/// - kStandardOutput: The empty string or "-", interpreted as standard output +/// - kPipeOutput: pipes, e.g. "| gzip -c > /tmp/abc.gz" +OutputType ClassifyWxfilename(const std::string &wxfilename); + +enum InputType { + kNoInput, + kFileInput, + kStandardInput, + kOffsetFileInput, + kPipeInput +}; + +/// ClassifyRxfilenames interprets filenames for reading as follows: +/// - kNoInput: invalid filenames (leading or trailing space, things that +/// look like wspecifiers and rspecifiers or pipes to write to +/// with trailing |. +/// - kFileInput: normal filenames +/// - kStandardInput: the empty string or "-" +/// - kPipeInput: e.g. "gunzip -c /tmp/abc.gz |" +/// - kOffsetFileInput: offsets into files, e.g. /some/filename:12970 +InputType ClassifyRxfilename(const std::string &rxfilename); + + +class Output { + public: + // The normal constructor, provided for convenience. + // Equivalent to calling with default constructor then Open() + // with these arguments. + Output(const std::string &filename, bool binary, bool write_header = true); + + Output(): impl_(NULL) {} + + /// This opens the stream, with the given mode (binary or text). It returns + /// true on success and false on failure. However, it will throw if something + /// was already open and could not be closed (to avoid this, call Close() + /// first. if write_header == true and binary == true, it writes the Kaldi + /// binary-mode header ('\0' then 'B'). 
You may call Open even if it is + /// already open; it will close the existing stream and reopen (however if + /// closing the old stream failed it will throw). + bool Open(const std::string &wxfilename, bool binary, bool write_header); + + inline bool IsOpen(); // return true if we have an open stream. Does not + // imply stream is good for writing. + + std::ostream &Stream(); // will throw if not open; else returns stream. + + // Close closes the stream. Calling Close is never necessary unless you + // want to avoid exceptions being thrown. There are times when calling + // Close will hurt efficiency (basically, when using offsets into files, + // and using the same Input object), + // but most of the time the user won't be doing this directly, it will + // be done in kaldi-table.{h, cc}, so you don't have to worry about it. + bool Close(); + + // This will throw if stream could not be closed (to check error status, + // call Close()). + ~Output(); + + private: + OutputImplBase *impl_; // non-NULL if open. + std::string filename_; + KALDI_DISALLOW_COPY_AND_ASSIGN(Output); +}; + + +// bool binary_in; +// Input ki(some_filename, &binary_in); +// MyObject.Read(ki.Stream(), binary_in); +// +// ... more extensive example: +// +// { +// bool binary_in; +// Input ki(some_filename, &binary_in); +// MyObject1.Read(ki.Stream(), &binary_in); +// MyObject2.Write(ki.Stream(), &binary_in); +// } +// Note that to catch errors you need to use try.. catch. +// Input communicates errors by throwing exceptions. + + +// Input interprets four kinds of filenames: +// (1) Normal filenames +// (2) The empty string or "-", interpreted as standard output +// (3) A pipe: e.g. "gunzip -c /tmp/abc.gz |" +// (4) Offsets into [real] files, e.g. "/my/filename:12049" +// The last one has no correspondence in Output. + + +class Input { + public: + /// The normal constructor. Opens the stream in binary mode. 
+ /// Equivalent to calling the default constructor followed by Open(); then, if + /// binary != NULL, it calls ReadHeader(), putting the output in "binary"; it + /// throws on error. + Input(const std::string &rxfilename, bool *contents_binary = NULL); + + Input(): impl_(NULL) {} + + // Open opens the stream for reading (the mode, where relevant, is binary; use + // OpenTextMode for text-mode, we made this a separate function rather than a + // boolean argument, to avoid confusion with Kaldi's text/binary distinction, + // since reading in the file system's text mode is unusual.) If + // contents_binary != NULL, it reads the binary-mode header and puts it in the + // "binary" variable. Returns true on success. If it returns false it will + // not be open. You may call Open even if it is already open; it will close + // the existing stream and reopen (however if closing the old stream failed it + // will throw). + inline bool Open(const std::string &rxfilename, bool *contents_binary = NULL); + + // As Open but (if the file system has text/binary modes) opens in text mode; + // you shouldn't ever have to use this as in Kaldi we read even text files in + // binary mode (and ignore the \r). + inline bool OpenTextMode(const std::string &rxfilename); + + // Return true if currently open for reading and Stream() will + // succeed. Does not guarantee that the stream is good. + inline bool IsOpen(); + + // It is never necessary or helpful to call Close, except if + // you are concerned about to many filehandles being open. + // Close does not throw. It returns the exit code as int32 + // in the case of a pipe [kPipeInput], and always zero otherwise. + int32 Close(); + + // Returns the underlying stream. Throws if !IsOpen() + std::istream &Stream(); + + // Destructor does not throw: input streams may legitimately fail so we + // don't worry about the status when we close them. 
+  ~Input();
+ private:
+  bool OpenInternal(const std::string &rxfilename, bool file_binary,
+                    bool *contents_binary);
+  InputImplBase *impl_;
+  KALDI_DISALLOW_COPY_AND_ASSIGN(Input);
+};
+
+template <class C> void ReadKaldiObject(const std::string &filename,
+                                        C *c) {
+  bool binary_in;
+  Input ki(filename, &binary_in);
+  c->Read(ki.Stream(), binary_in);
+}
+
+// Specialize the template for reading matrices, because we want to be able to
+// support reading 'ranges' (row and column ranges), like foo.mat[10:20].
+template <> void ReadKaldiObject(const std::string &filename,
+                                 Matrix<float> *m);
+
+
+template <> void ReadKaldiObject(const std::string &filename,
+                                 Matrix<double> *m);
+
+
+
+template <class C> inline void WriteKaldiObject(const C &c,
+                                                const std::string &filename,
+                                                bool binary) {
+  Output ko(filename, binary);
+  c.Write(ko.Stream(), binary);
+}
+
+/// PrintableRxfilename turns the rxfilename into a more human-readable
+/// form for error reporting, i.e. it does quoting and escaping and
+/// replaces "" or "-" with "standard input".
+std::string PrintableRxfilename(const std::string &rxfilename);
+
+/// PrintableWxfilename turns the wxfilename into a more human-readable
+/// form for error reporting, i.e. it does quoting and escaping and
+/// replaces "" or "-" with "standard output".
+std::string PrintableWxfilename(const std::string &wxfilename);
+
+/// @}
+
+}  // end namespace kaldi.
+
+#include "util/kaldi-io-inl.h"
+
+#endif  // KALDI_UTIL_KALDI_IO_H_
diff --git a/speechx/speechx/kaldi/util/kaldi-pipebuf.h b/speechx/speechx/kaldi/util/kaldi-pipebuf.h
new file mode 100644
index 00000000..61034ac2
--- /dev/null
+++ b/speechx/speechx/kaldi/util/kaldi-pipebuf.h
@@ -0,0 +1,87 @@
+// util/kaldi-pipebuf.h
+
+// Copyright 2009-2011  Ondrej Glembek
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+//  http://www.apache.org/licenses/LICENSE-2.0
+
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+
+/** @file kaldi-pipebuf.h
+ *  This is an Kaldi C++ Library header.
+ */
+
+#ifndef KALDI_UTIL_KALDI_PIPEBUF_H_
+#define KALDI_UTIL_KALDI_PIPEBUF_H_
+
+#include <stdio.h>
+#if !defined(_LIBCPP_VERSION)  // libc++
+#include <fstream>
+#else
+#include "util/basic-filebuf.h"
+#endif
+
+namespace kaldi {
+// This class provides a way to initialize a filebuf with a FILE* pointer
+// directly; it will not close the file pointer when it is deleted.
+// The C++ standard does not allow implementations of C++ to provide
+// this constructor within basic_filebuf, which makes it hard to deal
+// with pipes using completely native C++.  This is a workaround
+
+#ifdef _MSC_VER
+#elif defined(_LIBCPP_VERSION)  // libc++
+template <typename CharType, typename Traits = std::char_traits<CharType> >
+class basic_pipebuf : public basic_filebuf<CharType, Traits> {
+ public:
+  typedef basic_pipebuf<CharType, Traits> ThisType;
+
+ public:
+  basic_pipebuf(FILE *fptr, std::ios_base::openmode mode)
+      : basic_filebuf<CharType, Traits>() {
+    this->open(fptr, mode);
+    if (!this->is_open()) {
+      KALDI_WARN << "Error initializing pipebuf";  // probably indicates
+      // code error, if the fptr was good.
+      return;
+    }
+  }
+};  // class basic_pipebuf
+#else
+template <typename CharType, typename Traits = std::char_traits<CharType> >
+class basic_pipebuf : public std::basic_filebuf<CharType, Traits> {
+ public:
+  typedef basic_pipebuf<CharType, Traits> ThisType;
+
+ public:
+  basic_pipebuf(FILE *fptr, std::ios_base::openmode mode)
+      : std::basic_filebuf<CharType, Traits>() {
+    this->_M_file.sys_open(fptr, mode);
+    if (!this->is_open()) {
+      KALDI_WARN << "Error initializing pipebuf";  // probably indicates
+      // code error, if the fptr was good.
+ return; + } + this->_M_mode = mode; + this->_M_buf_size = BUFSIZ; + this->_M_allocate_internal_buffer(); + this->_M_reading = false; + this->_M_writing = false; + this->_M_set_buffer(-1); + } +}; // class basic_pipebuf +#endif // _MSC_VER + +} // namespace kaldi + +#endif // KALDI_UTIL_KALDI_PIPEBUF_H_ diff --git a/speechx/speechx/kaldi/util/kaldi-semaphore.cc b/speechx/speechx/kaldi/util/kaldi-semaphore.cc new file mode 100644 index 00000000..f0829ac0 --- /dev/null +++ b/speechx/speechx/kaldi/util/kaldi-semaphore.cc @@ -0,0 +1,57 @@ +// util/kaldi-semaphore.cc + +// Copyright 2012 Karel Vesely (Brno University of Technology) +// 2017 Dogan Can (University of Southern California) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+
+
+
+#include "base/kaldi-error.h"
+#include "util/kaldi-semaphore.h"
+
+namespace kaldi {
+
+Semaphore::Semaphore(int32 count) {
+  KALDI_ASSERT(count >= 0);
+  count_ = count;
+}
+
+Semaphore::~Semaphore() {}
+
+bool Semaphore::TryWait() {
+  std::unique_lock<std::mutex> lock(mutex_);
+  if(count_) {
+    count_--;
+    return true;
+  }
+  return false;
+}
+
+void Semaphore::Wait() {
+  std::unique_lock<std::mutex> lock(mutex_);
+  while(!count_)
+    condition_variable_.wait(lock);
+  count_--;
+}
+
+void Semaphore::Signal() {
+  std::unique_lock<std::mutex> lock(mutex_);
+  count_++;
+  condition_variable_.notify_one();
+}
+
+}  // namespace kaldi
diff --git a/speechx/speechx/kaldi/util/kaldi-semaphore.h b/speechx/speechx/kaldi/util/kaldi-semaphore.h
new file mode 100644
index 00000000..2562053c
--- /dev/null
+++ b/speechx/speechx/kaldi/util/kaldi-semaphore.h
@@ -0,0 +1,50 @@
+// util/kaldi-semaphore.h
+
+// Copyright 2012  Karel Vesely (Brno University of Technology)
+//           2017  Dogan Can (University of Southern California)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+
+#ifndef KALDI_THREAD_KALDI_SEMAPHORE_H_
+#define KALDI_THREAD_KALDI_SEMAPHORE_H_ 1
+
+#include <mutex>
+#include <condition_variable>
+
+namespace kaldi {
+
+class Semaphore {
+ public:
+  Semaphore(int32 count = 0);
+
+  ~Semaphore();
+
+  bool TryWait();  ///< Returns true if Wait() goes through
+  void Wait();     ///< decrease the counter
+  void Signal();   ///< increase the counter
+
+ private:
+  int32 count_;  ///< the semaphore counter, 0 means block on Wait()
+
+  std::mutex mutex_;
+  std::condition_variable condition_variable_;
+  KALDI_DISALLOW_COPY_AND_ASSIGN(Semaphore);
+};
+
+} //namespace
+
+#endif  // KALDI_THREAD_KALDI_SEMAPHORE_H_
diff --git a/speechx/speechx/kaldi/util/kaldi-table-inl.h b/speechx/speechx/kaldi/util/kaldi-table-inl.h
new file mode 100644
index 00000000..6aca2f13
--- /dev/null
+++ b/speechx/speechx/kaldi/util/kaldi-table-inl.h
@@ -0,0 +1,2672 @@
+// util/kaldi-table-inl.h
+
+// Copyright 2009-2011  Microsoft Corporation
+//                2013  Johns Hopkins University (author: Daniel Povey)
+//                2016  Xiaohui Zhang
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+
+#ifndef KALDI_UTIL_KALDI_TABLE_INL_H_
+#define KALDI_UTIL_KALDI_TABLE_INL_H_
+
+#include <algorithm>
+#include <limits>
+#include <map>
+#include <string>
+#include <utility>
+#include <vector>
+#include "util/kaldi-io.h"
+#include "util/kaldi-holder.h"
+#include "util/text-utils.h"
+#include "util/stl-utils.h" // for StringHasher. 
+#include "util/kaldi-semaphore.h"
+
+
+namespace kaldi {
+
+/// \addtogroup table_impl_types
+/// @{
+
+template<class Holder> class SequentialTableReaderImplBase {
+ public:
+  typedef typename Holder::T T;
+  // note that Open takes rxfilename not rspecifier.  Open will only be
+  // called on a just-allocated object.
+  virtual bool Open(const std::string &rxfilename) = 0;
+  // Done() should be called on a successfully opened, not-closed object.
+  // only throws if called at the wrong time (i.e. code error).
+  virtual bool Done() const = 0;
+  // Returns true if the reader is open [i.e. Open() succeeded and
+  // the user has not called Close()]
+  virtual bool IsOpen() const = 0;
+  // Returns the current key; it is valid to call this if Done() returned false.
+  // Only throws on code error (i.e. called at the wrong time).
+  virtual std::string Key() = 0;
+  // Returns the value associated with the current key.  Valid to call it if
+  // Done() returned false.  It throws if the value could not be read.  [However
+  // if you use the ,p modifier it will never throw, unless you call it at the
+  // wrong time, i.e. unless there is a code error.]
+  virtual T &Value() = 0;
+  virtual void FreeCurrent() = 0;
+  // move to the next object.  This won't throw unless called wrongly (e.g. on
+  // non-open archive.]
+  virtual void Next() = 0;
+  // Close the table.  Returns its status as bool so it won't throw, unless
+  // called wrongly [i.e. on non-open archive.]
+  virtual bool Close() = 0;
+  // SwapHolder() is not part of the public interface of SequentialTableReader.
+  // It should be called when it would be valid to call Value() or FreeCurrent()
+  // (i.e. when a value is stored), and after this it's not valid to get the
+  // value any more until you call Next().  It swaps the contents of
+  // this->holder_ with those of 'other_holder'.  It's needed as part of how
+  // we implement SequentialTableReaderBackgroundImpl. 
+  virtual void SwapHolder(Holder *other_holder) = 0;
+  SequentialTableReaderImplBase() { }
+  virtual ~SequentialTableReaderImplBase() { }  // throws.
+ private:
+  KALDI_DISALLOW_COPY_AND_ASSIGN(SequentialTableReaderImplBase);
+};
+
+// This is the implementation for SequentialTableReader<Holder>
+// when it's actually a script file.
+template<class Holder>  class SequentialTableReaderScriptImpl:
+      public SequentialTableReaderImplBase<Holder> {
+ public:
+  typedef typename Holder::T T;
+
+  SequentialTableReaderScriptImpl(): state_(kUninitialized) { }
+
+  // You may call Open from states kUninitialized and kError.
+  // It may leave the object in any of the states.
+  virtual bool Open(const std::string &rspecifier) {
+    if (state_ != kUninitialized && state_ != kError)
+      if (!Close())  // call Close() yourself to suppress this exception.
+        KALDI_ERR << "Error closing previous input: "
+                  << "rspecifier was " << rspecifier_;
+    bool binary;
+    rspecifier_ = rspecifier;
+    RspecifierType rs = ClassifyRspecifier(rspecifier, &script_rxfilename_,
+                                           &opts_);
+    KALDI_ASSERT(rs == kScriptRspecifier);
+    if (!script_input_.Open(script_rxfilename_, &binary)) {  // Failure on Open
+      KALDI_WARN << "Failed to open script file "
+                 << PrintableRxfilename(script_rxfilename_);
+      state_ = kUninitialized;
+      return false;
+    } else {  // Open succeeded.
+      if (binary) {
+        KALDI_WARN << "Script file should not be binary file.";
+        SetErrorState();
+        return false;
+      } else {
+        state_ = kFileStart;
+        Next();
+        if (state_ == kError)
+          return false;
+        // any other status, including kEof, is OK from the point of view of
+        // the 'open' function (empty scp file is not inherently an error). 
+ return true; + } + } + } + + virtual bool IsOpen() const { + switch (state_) { + case kEof: case kHaveScpLine: case kHaveObject: case kHaveRange: + return true; + case kUninitialized: case kError: + return false; + default: KALDI_ERR << "IsOpen() called on invalid object."; + // note: kFileStart is not a valid state for the user to call a member + // function (we never return from a public function in this state). + return false; + } + } + + virtual bool Done() const { + switch (state_) { + case kHaveScpLine: case kHaveObject: case kHaveRange: return false; + case kEof: case kError: return true; // Error condition, like Eof, counts + // as Done(); the destructor/Close() will inform the user of the error. + default: KALDI_ERR << "Done() called on TableReader object at the wrong" + " time."; + return false; + } + } + + virtual std::string Key() { + // Valid to call this whenever Done() returns false. + switch (state_) { + case kHaveScpLine: case kHaveObject: case kHaveRange: break; + default: + // coding error. + KALDI_ERR << "Key() called on TableReader object at the wrong time."; + } + return key_; + } + + T &Value() { + if (!EnsureObjectLoaded()) + KALDI_ERR << "Failed to load object from " + << PrintableRxfilename(data_rxfilename_) + << " (to suppress this error, add the permissive " + << "(p, ) option to the rspecifier."; + // Because EnsureObjectLoaded() returned with success, we know + // that if range_ is nonempty (i.e. a range was requested), the + // state will be kHaveRange. 
+ if (state_ == kHaveRange) { + return range_holder_.Value(); + } else { + KALDI_ASSERT(state_ == kHaveObject); + return holder_.Value(); + } + } + + void FreeCurrent() { + if (state_ == kHaveObject) { + holder_.Clear(); + state_ = kHaveScpLine; + } else if (state_ == kHaveRange) { + range_holder_.Clear(); + state_ = kHaveObject; + } else { + KALDI_WARN << "FreeCurrent called at the wrong time."; + } + } + + void SwapHolder(Holder *other_holder) { + // call Value() to ensure we have a value, and ignore its return value while + // suppressing compiler warnings by casting to void. It will cause the + // program to die with KALDI_ERR if we couldn't get a value. + (void) Value(); + // At this point we know that we successfully loaded an object, + // and if there was a range specified, it's in range_holder_. + if (state_ == kHaveObject) { + holder_.Swap(other_holder); + state_ = kHaveScpLine; + } else if (state_ == kHaveRange) { + range_holder_.Swap(other_holder); + state_ = kHaveObject; + // This indicates that we still have the base object (but no range). + } else { + KALDI_ERR << "Code error"; + } + // Note: after this call there may be some junk left in range_holder_ or + // holder_, but it won't matter. We avoid calling Clear() on them, as this + // function needs to be lightweight for the 'bg' feature to work well. + } + + // Next goes to the next object. + // It can leave the object in most of the statuses, but + // the only circumstances under which it will return are: + // either: + // - if Done() returned true, i.e. kError or kEof. + // or: + // - in non-permissive mode, status kHaveScpLine or kHaveObjecct + // - in permissive mode, only when we successfully have an object, + // which means either (kHaveObject and range_.empty()), or + // kHaveRange. + void Next() { + while (1) { + NextScpLine(); + if (Done()) return; + if (opts_.permissive) { + // Permissive mode means, when reading scp files, we treat keys whose + // scp entry cannot be read as nonexistent. 
This means trying to read. + if (EnsureObjectLoaded()) return; // Success. + // else try the next scp line. + } else { + return; // We go the next key; Value() will crash if we can't read the + // object on the scp line. + } + } + } + + // This function may be entered at in any state. At exit, the object will be + // in state kUninitialized. It only returns false in the situation where we + // were at the end of the stream (kEof) and the script_input_ was a pipe and + // it ended with error status; this is so that we can catch errors from + // programs that we invoked via a pipe. + virtual bool Close() { + int32 status = 0; + if (script_input_.IsOpen()) + status = script_input_.Close(); + if (data_input_.IsOpen()) + data_input_.Close(); + range_holder_.Clear(); + holder_.Clear(); + if (!this->IsOpen()) + KALDI_ERR << "Close() called on input that was not open."; + StateType old_state = state_; + state_ = kUninitialized; + if (old_state == kError || (old_state == kEof && status != 0)) { + if (opts_.permissive) { + KALDI_WARN << "Close() called on scp file with read error, ignoring the" + " error because permissive mode specified."; + return true; + } else { + return false; // User will do something with the error status. 
+ } + } else { + return true; + } + // Possible states Return value + // kLoadSucceeded/kRangeSucceeded/kRangeFailed true + // kError (if opts_.permissive) true + // kError (if !opts_.permissive) false + // kEof (if script_input_.Close() && !opts.permissive) false + // kEof (if !script_input_.Close() || opts.permissive) true + // kUninitialized/kFileStart/kHaveScpLine true + // kUnitialized true + } + + virtual ~SequentialTableReaderScriptImpl() { + if (this->IsOpen() && !Close()) + KALDI_ERR << "TableReader: reading script file failed: from scp " + << PrintableRxfilename(script_rxfilename_); + } + private: + + // Function EnsureObjectLoaded() ensures that we have fully loaded any object + // (including object range) associated with the current key, and returns true + // on success (i.e. we have the object) and false on failure. + // + // Possible entry states: kHaveScpLine, kLoadSucceeded, kRangeSucceeded + // + // Possible exit states: kHaveScpLine, kLoadSucceeded, kRangeSucceeded. + // + // Note: the return status has information that cannot be deduced from + // just the exit state. If the object could not be loaded we go to state + // kHaveScpLine but return false; and if the range was requested but + // could not be extracted, we go to state kLoadSucceeded but return false. + bool EnsureObjectLoaded() { + if (!(state_ == kHaveScpLine || state_ == kHaveObject || + state_ == kHaveRange)) + KALDI_ERR << "Invalid state (code error)"; + + if (state_ == kHaveScpLine) { // need to load the object into holder_. + bool ans; + // note, NULL means it doesn't read the binary-mode header + if (Holder::IsReadInBinary()) { + ans = data_input_.Open(data_rxfilename_, NULL); + } else { + ans = data_input_.OpenTextMode(data_rxfilename_); + } + if (!ans) { + KALDI_WARN << "Failed to open file " + << PrintableRxfilename(data_rxfilename_); + return false; + } else { + if (holder_.Read(data_input_.Stream())) { + state_ = kHaveObject; + } else { // holder_ will not contain data. 
+ KALDI_WARN << "Failed to load object from " + << PrintableRxfilename(data_rxfilename_); + return false; + } + } + } + // OK, at this point the state must be either + // kHaveObject or kHaveRange. + if (range_.empty()) { + // if range_ is the empty string, we should not be in the state + // kHaveRange. + KALDI_ASSERT(state_ == kHaveObject); + return true; + } + // range_ is nonempty. + if (state_ == kHaveRange) { + // range was already extracted, so there nothing to do. + return true; + } + // OK, range_ is nonempty and state_ is kHaveObject. We attempt to extract + // the range object. Note: ExtractRange() will throw with KALDI_ERR if the + // object type doesn't support ranges. + if (!range_holder_.ExtractRange(holder_, range_)) { + KALDI_WARN << "Failed to load object from " + << PrintableRxfilename(data_rxfilename_) + << "[" << range_ << "]"; + return false; + } else { + state_ = kHaveRange; + return true; + } + } + + void SetErrorState() { + state_ = kError; + script_input_.Close(); + data_input_.Close(); + holder_.Clear(); + range_holder_.Clear(); + } + + // Reads the next line in the script file. + // Possible entry states: kHaveObject, kHaveRange, kHaveScpLine, kFileStart. + // Possible exit states: kEof, kError, kHaveScpLine, kHaveObject. + void NextScpLine() { + switch (state_) { // Check and simplify the state. + case kHaveRange: + range_holder_.Clear(); + state_ = kHaveObject; + break; + case kHaveScpLine: case kHaveObject: case kFileStart: break; + default: + // No other states are valid to call Next() from. + KALDI_ERR << "Reading script file: Next called wrongly."; + } + // at this point the state will be kHaveObject, kHaveScpLine, or kFileStart. + std::string line; + if (getline(script_input_.Stream(), line)) { + // After extracting "key" from "line", we put the rest + // of "line" into "rest", and then extract data_rxfilename_ + // (e.g. 1.ark:100) and possibly the range_ specifer + // (e.g. [1:2,2:10]) from "rest". 
+ std::string data_rxfilename, rest; + SplitStringOnFirstSpace(line, &key_, &rest); + if (!key_.empty() && !rest.empty()) { + // Got a valid line. + if (rest[rest.size()-1] == ']') { + if(!ExtractRangeSpecifier(rest, &data_rxfilename, &range_)) { + KALDI_WARN << "Reading rspecifier '" << rspecifier_ + << ", cannot make sense of scp line " + << line; + SetErrorState(); + return; + } + } else { + data_rxfilename = rest; + range_ = ""; + } + bool filenames_equal = (data_rxfilename_ == data_rxfilename); + if (!filenames_equal) + data_rxfilename_ = data_rxfilename; + if (state_ == kHaveObject) { + if (!filenames_equal) { + holder_.Clear(); + state_ = kHaveScpLine; + } + // else leave state_ at kHaveObject and leave the object in the + // holder. + } else { + state_ = kHaveScpLine; + } + } else { + KALDI_WARN << "We got an invalid line in the scp file. " + << "It should look like: some_key 1.ark:10, got: " + << line; + SetErrorState(); + } + } else { + state_ = kEof; // there is nothing more in the scp file. Might as well + // close input streams as we don't need them. + script_input_.Close(); + if (data_input_.IsOpen()) + data_input_.Close(); + holder_.Clear(); // clear the holder if it was nonempty. + range_holder_.Clear(); // clear the range holder if it was nonempty. + } + } + + std::string rspecifier_; // the rspecifier that this class was opened with. + RspecifierOptions opts_; // options. + std::string script_rxfilename_; // rxfilename of the script file. + + Input script_input_; // Input object for the .scp file + Input data_input_; // Input object for the entries in the script file; + // we make this a class member instead of a local variable, + // so that rspecifiers of the form filename:byte-offset, + // e.g. foo.ark:12345, can be handled using fseek(). + + Holder holder_; // Holds the object. + Holder range_holder_; // Holds the partial object corresponding to the object + // range specifier 'range_'; this is only used when + // 'range_' is specified, i.e. 
when the .scp file + // contains lines of the form rspecifier[range], like + // foo.ark:242[0:9] (representing a row range of a + // matrix). + + + std::string key_; // the key of the current scp line we're processing + std::string data_rxfilename_; // the rxfilename corresponding to the current key + std::string range_; // the range of object corresponding to the current key, if an + // object range was specified in the script file, else "". + + enum StateType { + // Summary of the states this object can be in (state_). + // + // (*) Does holder_ contain the object corresponding to + // data_rxfilename_ ? + // (*) Does range_holder_ contain a range object? + // (*) is script_input_ open? + // (*) are key_, data_rxfilename_ and range_ [if applicable] set? + // + kUninitialized, // no no no no Uninitialized or closed object. + kFileStart, // no no yes no We just opened the .scp file (we'll never be in this + // state when a user-visible function is called.) + kEof, // no no no no We did Next() and found eof in script file. + kError, // no no no no Error reading or parsing script file. + kHaveScpLine, // no no yes yes Have a line of the script file but nothing else. + kHaveObject, // yes no yes yes holder_ contains an object but range_holder_ does not. + kHaveRange, // yes yes yes yes we have the range object in range_holder_ (implies + // range_ nonempty). + } state_; + + +}; + + +// This is the implementation for SequentialTableReader +// when it's an archive. Note that the archive format is: +// key1 [space] object1 key2 [space] +// object2 ... eof. +// "object1" is the output of the Holder::Write function and will +// typically contain a binary header (in binary mode) and then +// the output of object.Write(os, binary). +// The archive itself does not care whether it is in binary +// or text mode, for reading purposes. 
+ +template class SequentialTableReaderArchiveImpl: + public SequentialTableReaderImplBase { + public: + typedef typename Holder::T T; + + SequentialTableReaderArchiveImpl(): state_(kUninitialized) { } + + virtual bool Open(const std::string &rspecifier) { + if (state_ != kUninitialized) { + if (!Close()) { // call Close() yourself to suppress this exception. + if (opts_.permissive) + KALDI_WARN << "Error closing previous input " + "(only warning, since permissive mode)."; + else + KALDI_ERR << "Error closing previous input."; + } + } + rspecifier_ = rspecifier; + RspecifierType rs = ClassifyRspecifier(rspecifier, + &archive_rxfilename_, + &opts_); + KALDI_ASSERT(rs == kArchiveRspecifier); + + bool ans; + // NULL means don't expect binary-mode header + if (Holder::IsReadInBinary()) + ans = input_.Open(archive_rxfilename_, NULL); + else + ans = input_.OpenTextMode(archive_rxfilename_); + if (!ans) { // header. + KALDI_WARN << "Failed to open stream " + << PrintableRxfilename(archive_rxfilename_); + state_ = kUninitialized; // Failure on Open + return false; // User should print the error message. + } + state_ = kFileStart; + Next(); + if (state_ == kError) { + KALDI_WARN << "Error beginning to read archive file (wrong filename?): " + << PrintableRxfilename(archive_rxfilename_); + input_.Close(); + state_ = kUninitialized; + return false; + } + KALDI_ASSERT(state_ == kHaveObject || state_ == kEof); + return true; + } + + virtual void Next() { + switch (state_) { + case kHaveObject: + holder_.Clear(); + break; + case kFileStart: case kFreedObject: + break; + default: + KALDI_ERR << "Next() called wrongly."; + } + std::istream &is = input_.Stream(); + is.clear(); // Clear any fail bits that may have been set... just in case + // this happened in the Read function. + is >> key_; // This eats up any leading whitespace and gets the string. + if (is.eof()) { + state_ = kEof; + return; + } + if (is.fail()) { // This shouldn't really happen, barring file-system + // errors. 
+ KALDI_WARN << "Error reading archive " + << PrintableRxfilename(archive_rxfilename_); + state_ = kError; + return; + } + int c; + if ((c = is.peek()) != ' ' && c != '\t' && c != '\n') { // We expect a + // space ' ' after the key. + // We also allow tab [which is consumed] and newline [which is not], just + // so we can read archives generated by scripts that may not be fully + // aware of how this format works. + KALDI_WARN << "Invalid archive file format: expected space after key " + << key_ << ", got character " + << CharToString(static_cast(is.peek())) << ", reading " + << PrintableRxfilename(archive_rxfilename_); + state_ = kError; + return; + } + if (c != '\n') is.get(); // Consume the space or tab. + if (holder_.Read(is)) { + state_ = kHaveObject; + return; + } else { + KALDI_WARN << "Object read failed, reading archive " + << PrintableRxfilename(archive_rxfilename_); + state_ = kError; + return; + } + } + + virtual bool IsOpen() const { + switch (state_) { + case kEof: case kError: case kHaveObject: case kFreedObject: return true; + case kUninitialized: return false; + default: KALDI_ERR << "IsOpen() called on invalid object."; // kFileStart + // is not valid state for user to call something on. + return false; + } + } + + virtual bool Done() const { + switch (state_) { + case kHaveObject: + return false; + case kEof: case kError: + return true; // Error-state counts as Done(), but destructor + // will fail (unless you check the status with Close()). + default: + KALDI_ERR << "Done() called on TableReader object at the wrong time."; + return false; + } + } + + virtual std::string Key() { + // Valid to call this whenever Done() returns false + switch (state_) { + case kHaveObject: break; // only valid case. + default: + // coding error. + KALDI_ERR << "Key() called on TableReader object at the wrong time."; + } + return key_; + } + + T &Value() { + switch (state_) { + case kHaveObject: + break; // only valid case. + default: + // coding error. 
+ KALDI_ERR << "Value() called on TableReader object at the wrong time."; + } + return holder_.Value(); + } + + virtual void FreeCurrent() { + if (state_ == kHaveObject) { + holder_.Clear(); + state_ = kFreedObject; + } else { + KALDI_WARN << "FreeCurrent called at the wrong time."; + } + } + + void SwapHolder(Holder *other_holder) { + // call Value() to ensure we have a value, and ignore its return value while + // suppressing compiler warnings by casting to void. + (void) Value(); + if (state_ == kHaveObject) { + holder_.Swap(other_holder); + state_ = kFreedObject; + } else { + KALDI_ERR << "SwapHolder called at the wrong time " + "(error related to ',bg' modifier)."; + } + } + + virtual bool Close() { + // To clean up, Close() also closes the Input object if + // it's open. It will succeed if the stream was not in an error state, + // and the Input object isn't in an error state we've found eof in the archive. + if (!this->IsOpen()) + KALDI_ERR << "Close() called on TableReader twice or otherwise wrongly."; + int32 status = 0; + if (input_.IsOpen()) + status = input_.Close(); + if (state_ == kHaveObject) + holder_.Clear(); + StateType old_state = state_; + state_ = kUninitialized; + if (old_state == kError || (old_state == kEof && status != 0)) { + if (opts_.permissive) { + KALDI_WARN << "Error detected closing TableReader for archive " + << PrintableRxfilename(archive_rxfilename_) + << " but ignoring " + << "it as permissive mode specified."; + return true; + } else { + return false; + } + } else { + return true; + } + } + + virtual ~SequentialTableReaderArchiveImpl() { + if (this->IsOpen() && !Close()) + KALDI_ERR << "TableReader: error detected closing archive " + << PrintableRxfilename(archive_rxfilename_); + } + private: + Input input_; // Input object for the archive + Holder holder_; // Holds the object. 
+ std::string key_; + std::string rspecifier_; + std::string archive_rxfilename_; + RspecifierOptions opts_; + enum StateType { // [The state of the reading process] [does holder_ [is input_ + // have object] open] + kUninitialized, // Uninitialized or closed. no no + kFileStart, // [state we use internally: just opened.] no yes + kEof, // We did Next() and found eof in archive no no + kError, // Some other error no no + kHaveObject, // We read the key and the object after it. yes yes + kFreedObject, // The user called FreeCurrent(). no yes + } state_; +}; + +// this is for when someone adds the 'th' modifier; it wraps around the basic +// implementation and allows it to do the reading in a background thread. +template +class SequentialTableReaderBackgroundImpl: + public SequentialTableReaderImplBase { + public: + typedef typename Holder::T T; + + SequentialTableReaderBackgroundImpl( + SequentialTableReaderImplBase *base_reader): + base_reader_(base_reader) {} + + // This function ignores the rxfilename argument. + // We use the same function signature as the regular Open(), + // for convenience. + virtual bool Open(const std::string &rxfilename) { + KALDI_ASSERT(base_reader_ != NULL && + base_reader_->IsOpen()); // or code error. + { + thread_ = std::thread(SequentialTableReaderBackgroundImpl::run, + this); + } + + if (!base_reader_->Done()) + Next(); + return true; + } + + virtual bool IsOpen() const { + // Close() sets base_reader_ to NULL, and we never initialize this object + // with a non-open base_reader_, so no need to check if it's open. + return base_reader_ != NULL; + } + + void RunInBackground() { + try { + // This function is called in the background thread. The whole point of + // the background thread is that we don't want to do the actual reading + // (inside Next()) in the foreground. + while (base_reader_ != NULL && !base_reader_->Done()) { + consumer_sem_.Signal(); + // Here is where the consumer process (parent thread) gets to do its + // stuff. 
Principally it calls SwapHolder()-- a shallow swap that is + // cheap. + producer_sem_.Wait(); + // we check that base_reader_ is not NULL in case Close() was + // called in the main thread. + if (base_reader_ != NULL) + base_reader_->Next(); // here is where the work happens. + } + // this signal will be waited on in the Next() function of the foreground + // thread if it is still running, or Close() otherwise. + consumer_sem_.Signal(); + // this signal may be waited on in Close(). + consumer_sem_.Signal(); + } catch (...) { + // There is nothing we called above that could potentially throw due to + // user data. So we treat reaching this point as a code-error condition. + // Closing base_reader_ will trigger an exception in Next() in the main + // thread when it checks that base_reader_->IsOpen(). + if (base_reader_->IsOpen()) { + base_reader_->Close(); + delete base_reader_; + base_reader_ = NULL; + } + consumer_sem_.Signal(); + return; + } + } + static void run(SequentialTableReaderBackgroundImpl *object) { + object->RunInBackground(); + } + virtual bool Done() const { + return key_.empty(); + } + virtual std::string Key() { + if (key_.empty()) + KALDI_ERR << "Calling Key() at the wrong time."; + return key_; + } + virtual T &Value() { + if (key_.empty()) + KALDI_ERR << "Calling Value() at the wrong time."; + return holder_.Value(); + } + void SwapHolder(Holder *other_holder) { + KALDI_ERR << "SwapHolder() should not be called on this class."; + } + virtual void FreeCurrent() { + if (key_.empty()) + KALDI_ERR << "Calling FreeCurrent() at the wrong time."; + // note: ideally a call to Value() should crash if you have just called + // FreeCurrent(). For typical holders such as KaldiObjectHolder this will + // happen inside the holder_.Value() call. This won't be the case for all + // holders, but it's not a great loss (just a missed opportunity to spot a + // code error). 
+ holder_.Clear(); + } + virtual void Next() { + consumer_sem_.Wait(); + if (base_reader_ == NULL || !base_reader_->IsOpen()) + KALDI_ERR << "Error detected (likely code error) in background " + << "reader (',bg' option)"; + if (base_reader_->Done()) { + // there is nothing else to read. + key_ = ""; + } else { + key_ = base_reader_->Key(); + base_reader_->SwapHolder(&holder_); + } + // this Signal() tells the producer thread, in the background, + // that it's now safe to read the next value. + producer_sem_.Signal(); + } + + // note: we can be sure that Close() won't be called twice, as the TableReader + // object will delete this object after calling Close. + virtual bool Close() { + KALDI_ASSERT(base_reader_ != NULL && thread_.joinable()); + // wait until the producer thread is idle. + consumer_sem_.Wait(); + bool ans = true; + try { + ans = base_reader_->Close(); + } catch (...) { + ans = false; + } + delete base_reader_; + // setting base_reader_ to NULL will cause the loop in the producer thread + // to exit. + base_reader_ = NULL; + producer_sem_.Signal(); + + thread_.join(); + return ans; + } + ~SequentialTableReaderBackgroundImpl() { + if (base_reader_) { + if (!Close()) { + KALDI_ERR << "Error detected closing background reader " + << "(relates to ',bg' modifier)"; + } + } + } + private: + std::string key_; + Holder holder_; + // I couldn't figure out what to call these semaphores. consumer_sem_ is the + // one that the consumer (main thread) waits on; producer_sem_ is the one + // that the producer (background thread) waits on. 
+ Semaphore consumer_sem_; + Semaphore producer_sem_; + std::thread thread_; + SequentialTableReaderImplBase *base_reader_; + +}; + +template +SequentialTableReader::SequentialTableReader(const std::string + &rspecifier): impl_(NULL) { + if (rspecifier != "" && !Open(rspecifier)) + KALDI_ERR << "Error constructing TableReader: rspecifier is " << rspecifier; +} + +template +bool SequentialTableReader::Open(const std::string &rspecifier) { + if (IsOpen()) + if (!Close()) + KALDI_ERR << "Could not close previously open object."; + // now impl_ will be NULL. + + RspecifierOptions opts; + RspecifierType wt = ClassifyRspecifier(rspecifier, NULL, &opts); + switch (wt) { + case kArchiveRspecifier: + impl_ = new SequentialTableReaderArchiveImpl(); + break; + case kScriptRspecifier: + impl_ = new SequentialTableReaderScriptImpl(); + break; + case kNoRspecifier: default: + KALDI_WARN << "Invalid rspecifier " << rspecifier; + return false; + } + if (!impl_->Open(rspecifier)) { + delete impl_; + impl_ = NULL; + return false; // sub-object will have printed warnings. + } + if (opts.background) { + impl_ = new SequentialTableReaderBackgroundImpl( + impl_); + if (!impl_->Open("")) { + // the rxfilename is ignored in that Open() call. + // It should only return false on code error. + return false; + } + } + return true; +} + +template +bool SequentialTableReader::Close() { + CheckImpl(); + bool ans = impl_->Close(); + delete impl_; // We don't keep around empty impl_ objects. + impl_ = NULL; + return ans; +} + + +template +bool SequentialTableReader::IsOpen() const { + return (impl_ != NULL); // Because we delete the object whenever + // that object is not open. Thus, the IsOpen functions of the + // Impl objects are not really needed. +} + +template +std::string SequentialTableReader::Key() { + CheckImpl(); + return impl_->Key(); // this call may throw if called wrongly in other ways, + // e.g. eof. 
+} + + +template +void SequentialTableReader::FreeCurrent() { + CheckImpl(); + impl_->FreeCurrent(); +} + + +template +typename SequentialTableReader::T & +SequentialTableReader::Value() { + CheckImpl(); + return impl_->Value(); // This may throw (if EnsureObjectLoaded() returned false you + // are safe.). +} + + +template +void SequentialTableReader::Next() { + CheckImpl(); + impl_->Next(); +} + +template +bool SequentialTableReader::Done() { + CheckImpl(); + return impl_->Done(); +} + + +template +SequentialTableReader::~SequentialTableReader() { + delete impl_; + // Destructor of impl_ may throw. +} + + + +template class TableWriterImplBase { + public: + typedef typename Holder::T T; + + virtual bool Open(const std::string &wspecifier) = 0; + + // Write returns true on success, false on failure, but + // some errors may not be detected until we call Close(). + // It throws (via KALDI_ERR) if called wrongly. We could + // have just thrown on all errors, since this is what + // TableWriter does; it was designed this way because originally + // TableWriter::Write returned an exit status. + virtual bool Write(const std::string &key, const T &value) = 0; + + // Flush will flush any archive; it does not return error status, + // any errors will be reported on the next Write or Close. + virtual void Flush() = 0; + + virtual bool Close() = 0; + + virtual bool IsOpen() const = 0; + + // May throw on write error if Close was not called. + virtual ~TableWriterImplBase() { } + + TableWriterImplBase() { } + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(TableWriterImplBase); +}; + + +// The implementation of TableWriter we use when writing directly +// to an archive with no associated scp. 
+template +class TableWriterArchiveImpl: public TableWriterImplBase { + public: + typedef typename Holder::T T; + + virtual bool Open(const std::string &wspecifier) { + switch (state_) { + case kUninitialized: + break; + case kWriteError: + KALDI_ERR << "Opening stream, already open with write error."; + case kOpen: default: + if (!Close()) // throw because this error may not have been previously + // detected by the user. + KALDI_ERR << "Opening stream, error closing previously open stream."; + } + wspecifier_ = wspecifier; + WspecifierType ws = ClassifyWspecifier(wspecifier, + &archive_wxfilename_, + NULL, + &opts_); + KALDI_ASSERT(ws == kArchiveWspecifier); // or wrongly called. + + if (output_.Open(archive_wxfilename_, opts_.binary, false)) { // false + // means no binary header. + state_ = kOpen; + return true; + } else { + // stream will not be open. User will report this error + // (we return bool), so don't bother printing anything. + state_ = kUninitialized; + return false; + } + } + + virtual bool IsOpen() const { + switch (state_) { + case kUninitialized: return false; + case kOpen: case kWriteError: return true; + default: KALDI_ERR << "IsOpen() called on TableWriter in invalid state."; + } + return false; + } + + // Write returns true on success, false on failure, but + // some errors may not be detected till we call Close(). + virtual bool Write(const std::string &key, const T &value) { + switch (state_) { + case kOpen: break; + case kWriteError: + // user should have known from the last + // call to Write that there was a problem. + KALDI_WARN << "Attempting to write to invalid stream."; + return false; + case kUninitialized: default: + KALDI_ERR << "Write called on invalid stream"; + } + // state is now kOpen or kWriteError. + if (!IsToken(key)) // e.g. empty string or has spaces... 
+ KALDI_ERR << "Using invalid key " << key; + output_.Stream() << key << ' '; + if (!Holder::Write(output_.Stream(), opts_.binary, value)) { + KALDI_WARN << "Write failure to " + << PrintableWxfilename(archive_wxfilename_); + state_ = kWriteError; + return false; + } + if (state_ == kWriteError) return false; // Even if this Write seems to + // have succeeded, we fail because a previous Write failed and the archive + // may be corrupted and unreadable. + + if (opts_.flush) + Flush(); + return true; + } + + // Flush will flush any archive; it does not return error status, + // any errors will be reported on the next Write or Close. + virtual void Flush() { + switch (state_) { + case kWriteError: case kOpen: + output_.Stream().flush(); // Don't check error status. + return; + default: + KALDI_WARN << "Flush called on not-open writer."; + } + } + + virtual bool Close() { + if (!this->IsOpen() || !output_.IsOpen()) + KALDI_ERR << "Close called on a stream that was not open." + << this->IsOpen() << ", " << output_.IsOpen(); + bool close_success = output_.Close(); + if (!close_success) { + KALDI_WARN << "Error closing stream: wspecifier is " << wspecifier_; + state_ = kUninitialized; + return false; + } + if (state_ == kWriteError) { + KALDI_WARN << "Closing writer in error state: wspecifier is " + << wspecifier_; + state_ = kUninitialized; + return false; + } + state_ = kUninitialized; + return true; + } + + TableWriterArchiveImpl(): state_(kUninitialized) {} + + // May throw on write error if Close was not called. + virtual ~TableWriterArchiveImpl() { + if (!IsOpen()) return; + else if (!Close()) + KALDI_ERR << "At TableWriter destructor: Write failed or stream close " + << "failed: wspecifier is "<< wspecifier_; + } + + private: + Output output_; + WspecifierOptions opts_; + std::string wspecifier_; + std::string archive_wxfilename_; + enum { // is stream open? 
+ kUninitialized, // no + kOpen, // yes + kWriteError, // yes + } state_; +}; + + + + +// The implementation of TableWriter we use when writing to +// individual files (more generally, wxfilenames) specified +// in an scp file that we read. + +// Note: the code for this class is similar to +// RandomAccessTableReaderScriptImpl; try to keep them in sync. + +template +class TableWriterScriptImpl: public TableWriterImplBase { + public: + typedef typename Holder::T T; + + TableWriterScriptImpl(): last_found_(0), state_(kUninitialized) {} + + virtual bool Open(const std::string &wspecifier) { + switch (state_) { + case kReadScript: + KALDI_ERR << " Opening already open TableWriter: call Close first."; + case kUninitialized: case kNotReadScript: + break; + } + wspecifier_ = wspecifier; + WspecifierType ws = ClassifyWspecifier(wspecifier, + NULL, + &script_rxfilename_, + &opts_); + KALDI_ASSERT(ws == kScriptWspecifier); // or wrongly called. + KALDI_ASSERT(script_.empty()); // no way it could be nonempty at this point. + + if (!ReadScriptFile(script_rxfilename_, + true, // print any warnings + &script_)) { // error reading script file or invalid + // format + state_ = kNotReadScript; + return false; // no need to print further warnings. user gets the error. + } + std::sort(script_.begin(), script_.end()); + for (size_t i = 0; i+1 < script_.size(); i++) { + if (script_[i].first.compare(script_[i+1].first) >= 0) { + // script[i] not < script[i+1] in lexical order... 
+ KALDI_WARN << "Script file " << PrintableRxfilename(script_rxfilename_) + << " contains duplicate key " << script_[i].first; + state_ = kNotReadScript; + return false; + } + } + state_ = kReadScript; + return true; + } + + virtual bool IsOpen() const { return (state_ == kReadScript); } + + virtual bool Close() { + if (!IsOpen()) + KALDI_ERR << "Close() called on TableWriter that was not open."; + state_ = kUninitialized; + last_found_ = 0; + script_.clear(); + return true; + } + + // Write returns true on success, false on failure, but + // some errors may not be detected till we call Close(). + virtual bool Write(const std::string &key, const T &value) { + if (!IsOpen()) + KALDI_ERR << "Write called on invalid stream"; + + if (!IsToken(key)) // e.g. empty string or has spaces... + KALDI_ERR << "Using invalid key " << key; + + std::string wxfilename; + if (!LookupFilename(key, &wxfilename)) { + if (opts_.permissive) { + return true; // In permissive mode, it's as if we're writing to + // /dev/null for missing keys. + } else { + KALDI_WARN << "Script file " + << PrintableRxfilename(script_rxfilename_) + << " has no entry for key " < pr(key, ""); // Important that "" + // compares less than or equal to any string, so lower_bound points to the + // element that has the same key. + typedef typename std::vector > + ::const_iterator IterType; + IterType iter = std::lower_bound(script_.begin(), script_.end(), pr); + if (iter != script_.end() && iter->first == key) { + last_found_ = iter - script_.begin(); + *wxfilename = iter->second; + return true; + } else { + return false; + } + } + + + WspecifierOptions opts_; + std::string wspecifier_; + std::string script_rxfilename_; + + // the script_ variable contains pairs of (key, filename), sorted using + // std::sort. This can be used with binary_search to look up filenames for + // writing. 
If this becomes inefficient we can use std::unordered_map (but I + // suspect this wouldn't be significantly faster & would use more memory). + // If memory becomes a problem here, the user should probably be passing + // only the relevant part of the scp file rather than expecting us to get too + // clever in the code. + std::vector > script_; + size_t last_found_; // This is for an optimization used in LookupFilename. + + enum { + kUninitialized, + kReadScript, + kNotReadScript, // read of script failed. + } state_; +}; + + +// The implementation of TableWriter we use when writing directly +// to an archive plus an associated scp. +template +class TableWriterBothImpl: public TableWriterImplBase { + public: + typedef typename Holder::T T; + + virtual bool Open(const std::string &wspecifier) { + switch (state_) { + case kUninitialized: + break; + case kWriteError: + KALDI_ERR << "Opening stream, already open with write error."; + case kOpen: default: + if (!Close()) // throw because this error may not have been previously + // detected by user. + KALDI_ERR << "Opening stream, error closing previously open stream."; + } + wspecifier_ = wspecifier; + WspecifierType ws = ClassifyWspecifier(wspecifier, + &archive_wxfilename_, + &script_wxfilename_, + &opts_); + KALDI_ASSERT(ws == kBothWspecifier); // or wrongly called. + if (ClassifyWxfilename(archive_wxfilename_) != kFileOutput) + KALDI_WARN << "When writing to both archive and script, the script file " + "will generally not be interpreted correctly unless the archive is " + "an actual file: wspecifier = " << wspecifier; + + if (!archive_output_.Open(archive_wxfilename_, opts_.binary, false)) { + // false means no binary header. + state_ = kUninitialized; + return false; + } + if (!script_output_.Open(script_wxfilename_, false, false)) { // first + // false means text mode: script files always text-mode. second false + // means don't write header (doesn't matter for text mode). 
+ archive_output_.Close(); // Don't care about status: error anyway. + state_ = kUninitialized; + return false; + } + state_ = kOpen; + return true; + } + + virtual bool IsOpen() const { + switch (state_) { + case kUninitialized: return false; + case kOpen: case kWriteError: return true; + default: KALDI_ERR << "IsOpen() called on TableWriter in invalid state."; + } + return false; + } + + void MakeFilename(typename std::ostream::pos_type streampos, + std::string *output) const { + std::ostringstream ss; + ss << ':' << streampos; + KALDI_ASSERT(ss.str() != ":-1"); + *output = archive_wxfilename_ + ss.str(); + + // e.g. /some/file:12302. + // Note that we warned if archive_wxfilename_ is not an actual filename; + // the philosophy is we give the user rope and if they want to hang + // themselves, with it, fine. + } + + // Write returns true on success, false on failure, but + // some errors may not be detected till we call Close(). + virtual bool Write(const std::string &key, const T &value) { + switch (state_) { + case kOpen: break; + case kWriteError: + // user should have known from the last + // call to Write that there was a problem. Warn about it. + KALDI_WARN << "Writing to non-open TableWriter object."; + return false; + case kUninitialized: default: + KALDI_ERR << "Write called on invalid stream"; + } + // state is now kOpen or kWriteError. + if (!IsToken(key)) // e.g. empty string or has spaces... + KALDI_ERR << "Using invalid key " << key; + std::ostream &archive_os = archive_output_.Stream(); + archive_os << key << ' '; + typename std::ostream::pos_type archive_os_pos = archive_os.tellp(); + // position at start of Write() to archive. We will record this in the + // script file. + std::string offset_rxfilename; // rxfilename with offset into the archive, + // e.g. some_archive_name.ark:431541423 + MakeFilename(archive_os_pos, &offset_rxfilename); + + // Write to the script file first. 
+ // The idea is that we want to get all the information possible into the + // script file, to make it easier to unwind errors later. + std::ostream &script_os = script_output_.Stream(); + script_output_.Stream() << key << ' ' << offset_rxfilename << '\n'; + + if (!Holder::Write(archive_output_.Stream(), opts_.binary, value)) { + KALDI_WARN << "Write failure to" + << PrintableWxfilename(archive_wxfilename_); + state_ = kWriteError; + return false; + } + + if (script_os.fail()) { + KALDI_WARN << "Write failure to script file detected: " + << PrintableWxfilename(script_wxfilename_); + state_ = kWriteError; + return false; + } + + if (archive_os.fail()) { + KALDI_WARN << "Write failure to archive file detected: " + << PrintableWxfilename(archive_wxfilename_); + state_ = kWriteError; + return false; + } + + if (state_ == kWriteError) return false; // Even if this Write seems to + // have succeeded, we fail because a previous Write failed and the archive + // may be corrupted and unreadable. + + if (opts_.flush) + Flush(); + return true; + } + + // Flush will flush any archive; it does not return error status, + // any errors will be reported on the next Write or Close. + virtual void Flush() { + switch (state_) { + case kWriteError: case kOpen: + archive_output_.Stream().flush(); // Don't check error status. + script_output_.Stream().flush(); // Don't check error status. + return; + default: + KALDI_WARN << "Flush called on not-open writer."; + } + } + + virtual bool Close() { + if (!this->IsOpen()) + KALDI_ERR << "Close called on a stream that was not open."; + bool close_success = true; + if (archive_output_.IsOpen()) + if (!archive_output_.Close()) close_success = false; + if (script_output_.IsOpen()) + if (!script_output_.Close()) close_success = false; + bool ans = close_success && (state_ != kWriteError); + state_ = kUninitialized; + return ans; + } + + TableWriterBothImpl(): state_(kUninitialized) {} + + // May throw on write error if Close() was not called. 
+  // User can get the error status by calling Close().
+  virtual ~TableWriterBothImpl() {
+    if (!IsOpen()) return;
+    else if (!Close())
+      KALDI_ERR << "Write failed or stream close failed: "
+                << wspecifier_;
+  }
+
+ private:
+  Output archive_output_;
+  Output script_output_;
+  WspecifierOptions opts_;
+  std::string archive_wxfilename_;
+  std::string script_wxfilename_;
+  std::string wspecifier_;
+  enum {  // is stream open?
+    kUninitialized,  // no
+    kOpen,  // yes
+    kWriteError,  // yes
+  } state_;
+};
+
+
+template<class Holder>
+TableWriter<Holder>::TableWriter(const std::string &wspecifier): impl_(NULL) {
+  if (wspecifier != "" && !Open(wspecifier))
+    KALDI_ERR << "Failed to open table for writing with wspecifier: " << wspecifier
+              << ": errno (in case it's relevant) is: " << strerror(errno);
+}
+
+template<class Holder>
+bool TableWriter<Holder>::IsOpen() const {
+  return (impl_ != NULL);
+}
+
+
+template<class Holder>
+bool TableWriter<Holder>::Open(const std::string &wspecifier) {
+  if (IsOpen()) {
+    if (!Close())  // call Close() yourself to suppress this exception.
+      KALDI_ERR << "Failed to close previously open writer.";
+  }
+  KALDI_ASSERT(impl_ == NULL);
+  WspecifierType wtype = ClassifyWspecifier(wspecifier, NULL, NULL, NULL);
+  switch (wtype) {
+    case kBothWspecifier:
+      impl_ = new TableWriterBothImpl<Holder>();
+      break;
+    case kArchiveWspecifier:
+      impl_ = new TableWriterArchiveImpl<Holder>();
+      break;
+    case kScriptWspecifier:
+      impl_ = new TableWriterScriptImpl<Holder>();
+      break;
+    case kNoWspecifier: default:
+      KALDI_WARN << "ClassifyWspecifier: invalid wspecifier " << wspecifier;
+      return false;
+  }
+  if (impl_->Open(wspecifier)) {
+    return true;
+  } else {  // The class will have printed a more specific warning.
+    delete impl_;
+    impl_ = NULL;
+    return false;
+  }
+}
+
+template<class Holder>
+void TableWriter<Holder>::Write(const std::string &key,
+                                const T &value) const {
+  CheckImpl();
+  if (!impl_->Write(key, value))
+    KALDI_ERR << "Error in TableWriter::Write";
+  // More specific warning will have
+  // been printed in the Write function.
+}
+
+template<class Holder>
+void TableWriter<Holder>::Flush() {
+  CheckImpl();
+  impl_->Flush();
+}
+
+template<class Holder>
+bool TableWriter<Holder>::Close() {
+  CheckImpl();
+  bool ans = impl_->Close();
+  delete impl_;  // We don't keep around non-open impl_ objects
+  // [c.f. definition of IsOpen()]
+  impl_ = NULL;
+  return ans;
+}
+
+template<class Holder>
+TableWriter<Holder>::~TableWriter() {
+  if (IsOpen() && !Close()) {
+    KALDI_ERR << "Error closing TableWriter [in destructor].";
+  }
+}
+
+
+// Types of RandomAccessTableReader:
+// In principle, we would like to have four types of RandomAccessTableReader:
+// the 4 combinations [scp, archive], [seekable, not-seekable],
+// where if something is seekable we only store a file offset. However,
+// it seems sufficient for now to only implement two of these, in both
+// cases assuming it's not seekable so we never store file offsets and always
+// store either the scp line or the data in the archive. The reasons are:
+// (1)
+// For scp files, storing the actual entry is not that much more expensive
+// than storing the file offsets (since the entries are just filenames), and
+// avoids a lot of fseek operations that might be expensive.
+// (2)
+// For archive files, there is no real reason, if you have the archive file
+// on disk somewhere, why you wouldn't access it via its associated scp.
+// [i.e. write it as ark, scp]. The main reason to read archives directly
+// is if they are part of a pipe, and in this case it's not seekable, so
+// we implement only this case.
+//
+// Note that we will rarely in practice have to keep in memory everything in
+// the archive, as long as things are only read once from the archive (the
+// "o, " or "once" option) and as long as we keep our keys in sorted order;
+// to take advantage of this we need the "s, " (sorted) option, so we would
+// read archives as e.g. "s, o, ark:-" (this is the rspecifier we would use if
+// it was the standard input and these conditions held).
+
+template<class Holder> class RandomAccessTableReaderImplBase {
+ public:
+  typedef typename Holder::T T;
+
+  virtual bool Open(const std::string &rspecifier) = 0;
+
+  virtual bool HasKey(const std::string &key) = 0;
+
+  virtual const T &Value(const std::string &key) = 0;
+
+  virtual bool Close() = 0;
+
+  virtual ~RandomAccessTableReaderImplBase() {}
+};
+
+
+// Implementation of RandomAccessTableReader for a script file; for simplicity
+// we just read it in all in one go, as it's unlikely someone would generate
+// this from a pipe. In principle we could read it on-demand as for the
+// archives, but this would probably be overkill.
+
+// Note: the code for this class is similar to TableWriterScriptImpl:
+// try to keep them in sync.
+template<class Holder>
+class RandomAccessTableReaderScriptImpl:
+    public RandomAccessTableReaderImplBase<Holder> {
+ public:
+  typedef typename Holder::T T;
+
+  RandomAccessTableReaderScriptImpl(): last_found_(0), state_(kUninitialized) {}
+
+  virtual bool Open(const std::string &rspecifier) {
+    switch (state_) {
+      case kNotHaveObject: case kHaveObject: case kHaveRange:
+        KALDI_ERR << " Opening already open RandomAccessTableReader:"
+            " call Close first.";
+      case kUninitialized: case kNotReadScript:
+        break;
+    }
+    rspecifier_ = rspecifier;
+    RspecifierType rs = ClassifyRspecifier(rspecifier,
+                                           &script_rxfilename_,
+                                           &opts_);
+    KALDI_ASSERT(rs == kScriptRspecifier);  // or wrongly called.
+    KALDI_ASSERT(script_.empty());  // no way it could be nonempty at this point
+
+    if (!ReadScriptFile(script_rxfilename_,
+                        true,  // print any warnings
+                        &script_)) {  // error reading script file or invalid
+                                      // format
+      state_ = kNotReadScript;
+      return false;  // no need to print further warnings. user gets the error.
+    }
+
+    rspecifier_ = rspecifier;
+    // If opts_.sorted, the user has asserted that the keys are already sorted.
+    // Although we could easily sort them, we want to let the user know of this
+    // mistake.
This same mistake could have serious effects if used with an + // archive rather than a script. + if (!opts_.sorted) + std::sort(script_.begin(), script_.end()); + for (size_t i = 0; i + 1 < script_.size(); i++) { + if (script_[i].first.compare(script_[i+1].first) >= 0) { + // script[i] not < script[i+1] in lexical order... + bool same = (script_[i].first == script_[i+1].first); + KALDI_WARN << "Script file " << PrintableRxfilename(script_rxfilename_) + << (same ? " contains duplicate key: " : + " is not sorted (remove s, option or add ns, option):" + " key is ") << script_[i].first; + state_ = kNotReadScript; + return false; + } + } + state_ = kNotHaveObject; + key_ = ""; // make sure we don't have a key set + return true; + } + + virtual bool IsOpen() const { + return (state_ == kNotHaveObject || state_ == kHaveObject || + state_ == kHaveRange); + } + + virtual bool Close() { + if (!IsOpen()) + KALDI_ERR << "Close() called on RandomAccessTableReader that was not" + " open."; + holder_.Clear(); + range_holder_.Clear(); + state_ = kUninitialized; + last_found_ = 0; + script_.clear(); + key_ = ""; + range_ = ""; + data_rxfilename_ = ""; + // This cannot fail because any errors of a "global" nature would have been + // detected when we did Open(). With archives it's different. + return true; + } + + virtual bool HasKey(const std::string &key) { + bool preload = opts_.permissive; + // In permissive mode, we have to check that we can read + // the scp entry before we assert that the key is there. + return HasKeyInternal(key, preload); + } + + + // Write returns true on success, false on failure, but + // some errors may not be detected till we call Close(). + virtual const T& Value(const std::string &key) { + if (!HasKeyInternal(key, true)) // true == preload. 
+ KALDI_ERR << "Could not get item for key " << key + << ", rspecifier is " << rspecifier_ << " [to ignore this, " + << "add the p, (permissive) option to the rspecifier."; + KALDI_ASSERT(key_ == key); + if (state_ == kHaveObject) { + return holder_.Value(); + } else { + KALDI_ASSERT(state_ == kHaveRange); + return range_holder_.Value(); + } + } + + virtual ~RandomAccessTableReaderScriptImpl() { } + + private: + + // HasKeyInternal when called with preload == false just tells us whether the + // key is in the scp. With preload == true, which happens when the ,p + // (permissive) option is given in the rspecifier (or when called from + // Value()), it will also check that we can preload the object from disk + // (loading from the rxfilename in the scp), and only return true if we can. + // This function is called both from HasKey and from Value(). + virtual bool HasKeyInternal(const std::string &key, bool preload) { + switch (state_) { + case kUninitialized: case kNotReadScript: + KALDI_ERR << "HasKey called on RandomAccessTableReader object that is" + " not open."; + case kHaveObject: + if (key == key_ && range_.empty()) + return true; + break; + case kHaveRange: + if (key == key_) + return true; + break; + case kNotHaveObject: default: break; + } + KALDI_ASSERT(IsToken(key)); + size_t key_pos = 0; + if (!LookupKey(key, &key_pos)) { + return false; + } else { + if (!preload) { + return true; // we have the key, and were not asked to verify that the + // object could be read. + } else { // preload specified, so we have to attempt to pre-load the + // object before returning. + std::string data_rxfilename, range; // We will split + // script_[key_pos].second (e.g. "1.ark:100[0:2]" into data_rxfilename + // (e.g. "1.ark:100") and range (if any), e.g. "0:2". 
+ if (script_[key_pos].second[script_[key_pos].second.size()-1] == ']') { + if(!ExtractRangeSpecifier(script_[key_pos].second, + &data_rxfilename, + &range)) { + KALDI_ERR << "TableReader: failed to parse range in '" + << script_[key_pos].second << "'"; + } + } else { + data_rxfilename = script_[key_pos].second; + } + if (state_ == kHaveRange) { + if (data_rxfilename_ == data_rxfilename && range_ == range) { + // the odd situation where two keys had the same rxfilename and range: + // just change the key and keep the object. + key_ = key; + return true; + } else { + range_holder_.Clear(); + state_ = kHaveObject; + } + } + // OK, at this point the state will be kHaveObject or kNotHaveObject. + if (state_ == kHaveObject) { + if (data_rxfilename_ != data_rxfilename) { + // clear out the object. + state_ = kNotHaveObject; + holder_.Clear(); + } + } + // At this point we can safely switch to the new key, data_rxfilename + // and range, and we know that if we have an object, it will already be + // the correct one. The state is now kHaveObject or kNotHaveObject. + key_ = key; + data_rxfilename_ = data_rxfilename; + range_ = range; + if (state_ == kNotHaveObject) { + // we need to read the object. + if (!input_.Open(data_rxfilename)) { + KALDI_WARN << "Error opening stream " + << PrintableRxfilename(data_rxfilename); + return false; + } else { + if (holder_.Read(input_.Stream())) { + state_ = kHaveObject; + } else { + KALDI_WARN << "Error reading object from " + "stream " << PrintableRxfilename(data_rxfilename); + return false; + } + } + } + // At this point the state is kHaveObject. + if (range.empty()) + return true; // we're done: no range was requested. + if (range_holder_.ExtractRange(holder_, range)) { + state_ = kHaveRange; + return true; + } else { + KALDI_WARN << "Failed to load object from " + << PrintableRxfilename(data_rxfilename) + << "[" << range << "]"; + // leave state at kHaveObject. 
+          return false;
+        }
+      }
+    }
+  }
+
+  // This function attempts to look up the key "key" in the sorted array
+  // script_. If it was found it returns true and puts the array offset into
+  // 'script_offset'; otherwise it returns false.
+  bool LookupKey(const std::string &key, size_t *script_offset) {
+    // First, an optimization: if we're going consecutively, this will
+    // make the lookup very fast. Since we may call HasKey and then
+    // Value(), which both may look up the key, we test if either the
+    // current or next position are correct.
+    if (last_found_ < script_.size() && script_[last_found_].first == key) {
+      *script_offset = last_found_;
+      return true;
+    }
+    last_found_++;
+    if (last_found_ < script_.size() && script_[last_found_].first == key) {
+      *script_offset = last_found_;
+      return true;
+    }
+    std::pair<std::string, std::string> pr(key, "");  // Important that ""
+    // compares less than or equal to any string, so lower_bound points to the
+    // element that has the same key.
+    typedef typename std::vector<std::pair<std::string, std::string> >
+        ::const_iterator IterType;
+    IterType iter = std::lower_bound(script_.begin(), script_.end(), pr);
+    if (iter != script_.end() && iter->first == key) {
+      last_found_ = *script_offset = iter - script_.begin();
+      return true;
+    } else {
+      return false;
+    }
+  }
+
+
+  Input input_;  // Use the same input_ object for reading each file, in case
+  // the scp specifies offsets in an archive so we can keep the
+  // same file open.
+  RspecifierOptions opts_;
+  std::string rspecifier_;  // rspecifier used to open this object; used in
+  // debug messages
+  std::string script_rxfilename_;  // rxfilename of script file that we read.
+
+  std::string key_;  // The current key of the object that we have, but see the
+  // notes regarding states_ for more explanation of the
+  // semantics.
+
+  Holder holder_;
+  Holder range_holder_;  // Holds the partial object corresponding to the object
+  // range specifier 'range_'. this is only used when
+  // 'range_' is specified.
+  std::string range_;  // range within which we read the object from holder_.
+  // If key_ is set, always correspond to the key.
+  std::string data_rxfilename_;  // the rxfilename corresponding to key_,
+  // always set when key_ is set.
+
+
+  // the script_ variable contains pairs of (key, filename), sorted using
+  // std::sort. This can be used with binary_search to look up filenames for
+  // writing. If this becomes inefficient we can use std::unordered_map (but I
+  // suspect this wouldn't be significantly faster & would use more memory).
+  // If memory becomes a problem here, the user should probably be passing
+  // only the relevant part of the scp file rather than expecting us to get too
+  // clever in the code.
+  std::vector<std::pair<std::string, std::string> > script_;
+  size_t last_found_;  // This is for an optimization used in FindFilename.
+
+  enum {
+    //  (*) is script_ set up?
+    //  (*) does holder_ contain an object?
+    //  (*) does range_holder_ contain an object?
+    //
+    //
+    kUninitialized,  // no    no    no
+    kNotReadScript,  // no    no    no
+    kNotHaveObject,  // yes   no    no
+    kHaveObject,     // yes   yes   no
+    kHaveRange,      // yes   yes   yes
+
+    // If we are in a state where holder_ contains an object, it always contains
+    // the object from 'key_', and the corresponding rxfilename is always
+    // 'data_rxfilename_'. If range_holder_ contains an object, it always
+    // corresponds to the range 'range_' of the object in 'holder_', and always
+    // corresponds to the current key.
+  } state_;
+};
+
+
+
+
+// This is the base-class (with some implemented functions) for the
+// implementations of RandomAccessTableReader when it's an archive. This
+// base-class handles opening the files, storing the state of the reading
+// process, and loading objects. This is the only case in which we have
+// an intermediate class in the hierarchy between the virtual ImplBase
+// class and the actual Impl classes.
+// The child classes vary in the assumptions regarding sorting, etc.
+
+template<class Holder>
+class RandomAccessTableReaderArchiveImplBase:
+    public RandomAccessTableReaderImplBase<Holder> {
+ public:
+  typedef typename Holder::T T;
+
+  RandomAccessTableReaderArchiveImplBase(): holder_(NULL),
+                                            state_(kUninitialized) { }
+
+  virtual bool Open(const std::string &rspecifier) {
+    if (state_ != kUninitialized) {
+      if (!this->Close())  // call Close() yourself to suppress this exception.
+        KALDI_ERR << "Error closing previous input.";
+    }
+    rspecifier_ = rspecifier;
+    RspecifierType rs = ClassifyRspecifier(rspecifier, &archive_rxfilename_,
+                                           &opts_);
+    KALDI_ASSERT(rs == kArchiveRspecifier);
+
+    // NULL means don't expect binary-mode header
+    bool ans;
+    if (Holder::IsReadInBinary())
+      ans = input_.Open(archive_rxfilename_, NULL);
+    else
+      ans = input_.OpenTextMode(archive_rxfilename_);
+    if (!ans) {  // header.
+      KALDI_WARN << "Failed to open stream "
+                 << PrintableRxfilename(archive_rxfilename_);
+      state_ = kUninitialized;  // Failure on Open
+      return false;  // User should print the error message.
+    } else {
+      state_ = kNoObject;
+    }
+    return true;
+  }
+
+  // ReadNextObject() requires that the state be kNoObject,
+  // and it will try to read the next object. If it succeeds,
+  // it sets the state to kHaveObject, and
+  // cur_key_ and holder_ have the key and value. If it fails,
+  // it sets the state to kError or kEof.
+  void ReadNextObject() {
+    if (state_ != kNoObject)
+      KALDI_ERR << "ReadNextObject() called from wrong state.";
+    // Code error somewhere in this class or a child class.
+    std::istream &is = input_.Stream();
+    is.clear();  // Clear any fail bits that may have been set... just in case
+    // this happened in the Read function.
+    is >> cur_key_;  // This eats up any leading whitespace and gets the string.
+    if (is.eof()) {
+      state_ = kEof;
+      return;
+    }
+    if (is.fail()) {  // This shouldn't really happen, barring file-system
+                      // errors.
+      KALDI_WARN << "Error reading archive: rspecifier is " << rspecifier_;
+      state_ = kError;
+      return;
+    }
+    int c;
+    if ((c = is.peek()) != ' ' && c != '\t' && c != '\n') {  // We expect a
+      // space ' ' after the key.
+      // We also allow tab, just so we can read archives generated by scripts
+      // that may not be fully aware of how this format works.
+      KALDI_WARN << "Invalid archive file format: expected space after key "
+                 << cur_key_ << ", got character "
+                 << CharToString(static_cast<char>(is.peek()))
+                 << ", reading archive "
+                 << PrintableRxfilename(archive_rxfilename_);
+      state_ = kError;
+      return;
+    }
+    if (c != '\n') is.get();  // Consume the space or tab.
+    holder_ = new Holder;
+    if (holder_->Read(is)) {
+      state_ = kHaveObject;
+      return;
+    } else {
+      KALDI_WARN << "Object read failed, reading archive "
+                 << PrintableRxfilename(archive_rxfilename_);
+      state_ = kError;
+      delete holder_;
+      holder_ = NULL;
+      return;
+    }
+  }
+
+  virtual bool IsOpen() const {
+    switch (state_) {
+      case kEof: case kError: case kHaveObject: case kNoObject: return true;
+      case kUninitialized: return false;
+      default: KALDI_ERR << "IsOpen() called on invalid object.";
+        return false;
+    }
+  }
+
+  // Called by the child-class virtual Close() functions; does the
+  // shared parts of the cleanup.
+  bool CloseInternal() {
+    if (!this->IsOpen())
+      KALDI_ERR << "Close() called on TableReader twice or otherwise wrongly.";
+    if (input_.IsOpen())
+      input_.Close();
+    if (state_ == kHaveObject) {
+      KALDI_ASSERT(holder_ != NULL);
+      delete holder_;
+      holder_ = NULL;
+    } else {
+      KALDI_ASSERT(holder_ == NULL);
+    }
+    bool ans = (state_ != kError);
+    state_ = kUninitialized;
+    if (!ans && opts_.permissive) {
+      KALDI_WARN << "Error state detected closing reader. "
+                 << "Ignoring it because you specified permissive mode.";
+      return true;
+    }
+    return ans;
+  }
+
+  ~RandomAccessTableReaderArchiveImplBase() {
+    // The child class has the responsibility to call CloseInternal().
+ KALDI_ASSERT(state_ == kUninitialized && holder_ == NULL); + } + private: + Input input_; // Input object for the archive + protected: + // The variables below are accessed by child classes. + + std::string cur_key_; // current key (if state == kHaveObject). + Holder *holder_; // Holds the object we just read (if state == kHaveObject) + + std::string rspecifier_; + std::string archive_rxfilename_; + RspecifierOptions opts_; + + enum { // [The state of the reading process] [does holder_ [is input_ + // have object] open] + kUninitialized, // Uninitialized or closed no no + kNoObject, // Do not have object in holder_ no yes + kHaveObject, // Have object in holder_ yes yes + kEof, // End of file no yes + kError, // Some kind of error-state in the reading. no yes + } state_; +}; + + +// RandomAccessTableReaderDSortedArchiveImpl (DSorted for "doubly sorted") is +// the implementation for random-access reading of archives when both the +// archive, and the calling code, are in sorted order (i.e. we ask for the keys +// in sorted order). This is when the s and cs options are both given. It only +// ever has to keep one object in memory. 
It inherits from +// RandomAccessTableReaderArchiveImplBase which implements the common parts of +// RandomAccessTableReader that are used when it's an archive we're reading from + +template +class RandomAccessTableReaderDSortedArchiveImpl: + public RandomAccessTableReaderArchiveImplBase { + using RandomAccessTableReaderArchiveImplBase::kUninitialized; + using RandomAccessTableReaderArchiveImplBase::kHaveObject; + using RandomAccessTableReaderArchiveImplBase::kNoObject; + using RandomAccessTableReaderArchiveImplBase::kEof; + using RandomAccessTableReaderArchiveImplBase::kError; + using RandomAccessTableReaderArchiveImplBase::state_; + using RandomAccessTableReaderArchiveImplBase::opts_; + using RandomAccessTableReaderArchiveImplBase::cur_key_; + using RandomAccessTableReaderArchiveImplBase::holder_; + using RandomAccessTableReaderArchiveImplBase::rspecifier_; + using RandomAccessTableReaderArchiveImplBase::archive_rxfilename_; + using RandomAccessTableReaderArchiveImplBase::ReadNextObject; + public: + typedef typename Holder::T T; + + RandomAccessTableReaderDSortedArchiveImpl() { } + + virtual bool Close() { + // We don't have anything additional to clean up, so just + // call generic base-class one. + return this->CloseInternal(); + } + + virtual bool HasKey(const std::string &key) { + return FindKeyInternal(key); + } + virtual const T & Value(const std::string &key) { + if (!FindKeyInternal(key)) { + KALDI_ERR << "Value() called but no such key " << key + << " in archive " << PrintableRxfilename(archive_rxfilename_); + } + KALDI_ASSERT(this->state_ == kHaveObject && key == this->cur_key_ + && holder_ != NULL); + return this->holder_->Value(); + } + + virtual ~RandomAccessTableReaderDSortedArchiveImpl() { + if (this->IsOpen()) + if (!Close()) // more specific warning will already have been printed. + // we are in some kind of error state & user did not find out by + // calling Close(). 
+ KALDI_ERR << "Error closing RandomAccessTableReader: rspecifier is " + << rspecifier_; + } + private: + // FindKeyInternal tries to find the key by calling "ReadNextObject()" + // as many times as necessary till we get to it. It is called from + // both FindKey and Value(). + bool FindKeyInternal(const std::string &key) { + // First check that the user is calling us right: should be + // in sorted order. If not, error. + if (!last_requested_key_.empty()) { + if (key.compare(last_requested_key_) < 0) { // key < last_requested_key_ + KALDI_ERR << "You provided the \"cs\" option " + << "but are not calling with keys in sorted order: " + << key << " < " << last_requested_key_ << ": rspecifier is " + << rspecifier_; + } + } + // last_requested_key_ is just for debugging of order of calling. + last_requested_key_ = key; + + if (state_ == kNoObject) + ReadNextObject(); // This can only happen + // once, the first time someone calls HasKey() or Value(). We don't + // do it in the initializer to stop the program hanging too soon, + // if reading from a pipe. + + if (state_ == kEof || state_ == kError) return false; + + if (state_ == kUninitialized) + KALDI_ERR << "Trying to access a RandomAccessTableReader object that is" + " not open."; + + std::string last_key_; // To check that + // the archive we're reading is in sorted order. + while (1) { + KALDI_ASSERT(state_ == kHaveObject); + int compare = key.compare(cur_key_); + if (compare == 0) { // key == key_ + return true; // we got it.. + } else if (compare < 0) { // key < cur_key_, so we already read past the + // place where we want to be. This implies that we will never find it + // [due to the sorting etc., this means it just isn't in the archive]. + return false; + } else { // compare > 0, key > cur_key_. We need to read further ahead. + last_key_ = cur_key_; + // read next object.. we have to set state to kNoObject first. 
+        KALDI_ASSERT(holder_ != NULL);
+        delete holder_;
+        holder_ = NULL;
+        state_ = kNoObject;
+        ReadNextObject();
+        if (state_ != kHaveObject)
+          return false;  // eof or read error.
+        if (cur_key_.compare(last_key_) <= 0) {
+          KALDI_ERR << "You provided the \"s\" option "
+                    << " (sorted order), but keys are out of order or"
+                       " duplicated: "
+                    << last_key_ << " is followed by " << cur_key_
+                    << ": rspecifier is " << rspecifier_;
+        }
+      }
+    }
+  }
+
+  /// Last string provided to HasKey() or Value();
+  std::string last_requested_key_;
+};
+
+// RandomAccessTableReaderSortedArchiveImpl is for random-access reading of
+// archives when the user specified the sorted (s) option but not the
+// called-sorted (cs) options.
+template<class Holder>
+class RandomAccessTableReaderSortedArchiveImpl:
+    public RandomAccessTableReaderArchiveImplBase<Holder> {
+  using RandomAccessTableReaderArchiveImplBase<Holder>::kUninitialized;
+  using RandomAccessTableReaderArchiveImplBase<Holder>::kHaveObject;
+  using RandomAccessTableReaderArchiveImplBase<Holder>::kNoObject;
+  using RandomAccessTableReaderArchiveImplBase<Holder>::kEof;
+  using RandomAccessTableReaderArchiveImplBase<Holder>::kError;
+  using RandomAccessTableReaderArchiveImplBase<Holder>::state_;
+  using RandomAccessTableReaderArchiveImplBase<Holder>::opts_;
+  using RandomAccessTableReaderArchiveImplBase<Holder>::cur_key_;
+  using RandomAccessTableReaderArchiveImplBase<Holder>::holder_;
+  using RandomAccessTableReaderArchiveImplBase<Holder>::rspecifier_;
+  using RandomAccessTableReaderArchiveImplBase<Holder>::archive_rxfilename_;
+  using RandomAccessTableReaderArchiveImplBase<Holder>::ReadNextObject;
+
+ public:
+  typedef typename Holder::T T;
+
+  RandomAccessTableReaderSortedArchiveImpl():
+      last_found_index_(static_cast<size_t>(-1)),
+      pending_delete_(static_cast<size_t>(-1)) { }
+
+  virtual bool Close() {
+    for (size_t i = 0; i < seen_pairs_.size(); i++)
+      delete seen_pairs_[i].second;
+    seen_pairs_.clear();
+
+    pending_delete_ = static_cast<size_t>(-1);
+    last_found_index_ = static_cast<size_t>(-1);
+
+    return this->CloseInternal();
+  }
+  virtual bool HasKey(const std::string
&key) {
+    HandlePendingDelete();
+    size_t index;
+    bool ans = FindKeyInternal(key, &index);
+    if (ans && opts_.once && seen_pairs_[index].second == NULL) {
+      // Just do a check RE the once option. "&&opts_.once" is for
+      // efficiency since this can only happen in that case.
+      KALDI_ERR << "Error: HasKey called after Value() already called for "
+                << " that key, and once (o) option specified: rspecifier is "
+                << rspecifier_;
+    }
+    return ans;
+  }
+  virtual const T & Value(const std::string &key) {
+    HandlePendingDelete();
+    size_t index;
+    if (!FindKeyInternal(key, &index)) {
+      KALDI_ERR << "Value() called but no such key " << key
+                << " in archive " << PrintableRxfilename(archive_rxfilename_);
+    }
+    if (seen_pairs_[index].second == NULL) {  // can happen if opts.once_
+      KALDI_ERR << "Error: Value() called more than once for key "
+                << key << " and once (o) option specified: rspecifier is "
+                << rspecifier_;
+    }
+    if (opts_.once)
+      pending_delete_ = index;  // mark this index to be deleted on next call.
+    return seen_pairs_[index].second->Value();
+  }
+  virtual ~RandomAccessTableReaderSortedArchiveImpl() {
+    if (this->IsOpen())
+      if (!Close())  // more specific warning will already have been printed.
+        // we are in some kind of error state & user did not find out by
+        // calling Close().
+        KALDI_ERR << "Error closing RandomAccessTableReader: rspecifier is "
+                  << rspecifier_;
+  }
+ private:
+  void HandlePendingDelete() {
+    const size_t npos = static_cast<size_t>(-1);
+    if (pending_delete_ != npos) {
+      KALDI_ASSERT(pending_delete_ < seen_pairs_.size());
+      KALDI_ASSERT(seen_pairs_[pending_delete_].second != NULL);
+      delete seen_pairs_[pending_delete_].second;
+      seen_pairs_[pending_delete_].second = NULL;
+      pending_delete_ = npos;
+    }
+  }
+
+  // FindKeyInternal tries to find the key in the array "seen_pairs_".
+  // If it is not already there, it reads ahead as far as necessary
+  // to determine whether we have the key or not.
On success it returns + // true and puts the index into the array seen_pairs_, into "index"; + // on failure it returns false. + // It will leave the state as either kNoObject, kEof or kError. + // FindKeyInternal does not do any checking about whether you are asking + // about a key that has been already given (with the "once" option). + // That is the user's responsibility. + + bool FindKeyInternal(const std::string &key, size_t *index) { + // First, an optimization in case the previous call was for the + // same key, and we found it. + if (last_found_index_ < seen_pairs_.size() + && seen_pairs_[last_found_index_].first == key) { + *index = last_found_index_; + return true; + } + + if (state_ == kUninitialized) + KALDI_ERR << "Trying to access a RandomAccessTableReader object that is" + " not open."; + + // Step one is to see whether we have to read ahead for the object.. + // Note, the possible states right now are kNoObject, kEof or kError. + // We are never in the state kHaveObject except just after calling + // ReadNextObject(). + bool looped = false; + while (state_ == kNoObject && + (seen_pairs_.empty() || key.compare(seen_pairs_.back().first) > 0)) { + looped = true; + // Read this as: + // while ( the stream is potentially good for reading && + // ([got no keys] || key > most_recent_key) ) { ... + // Try to read a new object. + // Note that the keys in seen_pairs_ are ordered from least to greatest. + ReadNextObject(); + if (state_ == kHaveObject) { // Successfully read object. + if (!seen_pairs_.empty() && // This is just a check. + cur_key_.compare(seen_pairs_.back().first) <= 0) { + // read the expression above as: !( cur_key_ > previous_key). + // it means we are not in sorted order [the user specified that we + // are, or we would not be using this implementation]. 
+        KALDI_ERR << "You provided the sorted (s) option but keys in archive "
+                  << PrintableRxfilename(archive_rxfilename_) << " are not "
+                  << "in sorted order: " << seen_pairs_.back().first
+                  << " is followed by " << cur_key_;
+        }
+        KALDI_ASSERT(holder_ != NULL);
+        seen_pairs_.push_back(std::make_pair(cur_key_, holder_));
+        holder_ = NULL;
+        state_ = kNoObject;
+      }
+    }
+    if (looped) {  // We only need to check the last element of the seen_pairs_
+      // array, since we would not have read more after getting "key".
+      if (!seen_pairs_.empty() && seen_pairs_.back().first == key) {
+        last_found_index_ = *index = seen_pairs_.size() - 1;
+        return true;
+      } else {
+        return false;
+      }
+    }
+    // Now we have to do an actual binary search in the seen_pairs_ array.
+    std::pair<std::string, Holder*> pr(key, static_cast<Holder*>(NULL));
+    typename std::vector<std::pair<std::string, Holder*> >::iterator
+        iter = std::lower_bound(seen_pairs_.begin(), seen_pairs_.end(),
+                                pr, PairCompare());
+    if (iter != seen_pairs_.end() &&
+        key == iter->first) {
+      last_found_index_ = *index = (iter - seen_pairs_.begin());
+      return true;
+    } else {
+      return false;
+    }
+  }
+
+  // These are the pairs of (key, object) we have read. We keep all the keys we
+  // have read but the actual objects (if they are stored with pointers inside
+  // the Holder object) may be deallocated if once == true, and the Holder
+  // pointer set to NULL.
+  std::vector<std::pair<std::string, Holder*> > seen_pairs_;
+  size_t last_found_index_;  // An optimization s.t. if FindKeyInternal called
+  // twice with same key (as it often will), it doesn't have to do the key
+  // search twice.
+  size_t pending_delete_;  // If opts_.once == true, this is the index of
+  // element of seen_pairs_ that is pending deletion.
+  struct PairCompare {
+    // PairCompare is the Less-than operator for the pairs of (key, Holder).
+    // compares the keys.
+    inline bool operator() (const std::pair<std::string, Holder*> &pr1,
+                            const std::pair<std::string, Holder*> &pr2) {
+      return (pr1.first.compare(pr2.first) < 0);
+    }
+  };
+};
+
+
+
+// RandomAccessTableReaderUnsortedArchiveImpl is for random-access reading of
+// archives when the user does not specify the sorted (s) option (in this case
+// the called-sorted, or "cs" option, is ignored). This is the least efficient
+// of the random access archive readers, in general, but it can be as efficient
+// as the others, in speed, memory and latency, if the "once" option is
+// specified and it happens that the keys of the archive are the same as the
+// keys the code is called with (to HasKey() and Value()), and in the same
+// order. However, if you ask it for a key that's not present it will have to
+// read the archive till the end and store it all in memory.
+
+template<class Holder>
+class RandomAccessTableReaderUnsortedArchiveImpl:
+    public RandomAccessTableReaderArchiveImplBase<Holder> {
+  using RandomAccessTableReaderArchiveImplBase<Holder>::kUninitialized;
+  using RandomAccessTableReaderArchiveImplBase<Holder>::kHaveObject;
+  using RandomAccessTableReaderArchiveImplBase<Holder>::kNoObject;
+  using RandomAccessTableReaderArchiveImplBase<Holder>::kEof;
+  using RandomAccessTableReaderArchiveImplBase<Holder>::kError;
+  using RandomAccessTableReaderArchiveImplBase<Holder>::state_;
+  using RandomAccessTableReaderArchiveImplBase<Holder>::opts_;
+  using RandomAccessTableReaderArchiveImplBase<Holder>::cur_key_;
+  using RandomAccessTableReaderArchiveImplBase<Holder>::holder_;
+  using RandomAccessTableReaderArchiveImplBase<Holder>::rspecifier_;
+  using RandomAccessTableReaderArchiveImplBase<Holder>::archive_rxfilename_;
+  using RandomAccessTableReaderArchiveImplBase<Holder>::ReadNextObject;
+
+  typedef typename Holder::T T;
+
+ public:
+  RandomAccessTableReaderUnsortedArchiveImpl(): to_delete_iter_(map_.end()),
+                                                to_delete_iter_valid_(false) {
+    map_.max_load_factor(0.5);  // make it quite empty -> quite efficient.
+    // default seems to be 1.
+  }
+
+  virtual bool Close() {
+    for (typename MapType::iterator iter = map_.begin();
+         iter != map_.end();
+         ++iter) {
+      delete iter->second;
+    }
+    map_.clear();
+    first_deleted_string_ = "";
+    to_delete_iter_valid_ = false;
+    return this->CloseInternal();
+  }
+
+  virtual bool HasKey(const std::string &key) {
+    HandlePendingDelete();
+    return FindKeyInternal(key, NULL);
+  }
+  virtual const T & Value(const std::string &key) {
+    HandlePendingDelete();
+    const T *ans_ptr = NULL;
+    if (!FindKeyInternal(key, &ans_ptr))
+      KALDI_ERR << "Value() called but no such key " << key
+                << " in archive " << PrintableRxfilename(archive_rxfilename_);
+    return *ans_ptr;
+  }
+  virtual ~RandomAccessTableReaderUnsortedArchiveImpl() {
+    if (this->IsOpen())
+      if (!Close())  // more specific warning will already have been printed.
+        // we are in some kind of error state & user did not find out by
+        // calling Close().
+        KALDI_ERR << "Error closing RandomAccessTableReader: rspecifier is "
+                  << rspecifier_;
+  }
+ private:
+  void HandlePendingDelete() {
+    if (to_delete_iter_valid_) {
+      to_delete_iter_valid_ = false;
+      delete to_delete_iter_->second;  // Delete Holder object.
+      if (first_deleted_string_.length() == 0)
+        first_deleted_string_ = to_delete_iter_->first;
+      map_.erase(to_delete_iter_);  // delete that element.
+    }
+  }
+
+  // FindKeyInternal tries to find the key in the map "map_"
+  // If it is not already there, it reads ahead either until it finds the
+  // key, or until end of file. If called with value_ptr == NULL,
+  // it assumes it's called from HasKey() and just returns true or false
+  // and doesn't otherwise have side effects. If called with value_ptr !=
+  // NULL, it assumes it's called from Value(). Thus, it will crash
+  // if it cannot find the key. If it can find it it puts its address in
+  // *value_ptr, and if opts_once == true it will mark that element of the
+  // map to be deleted.
+
+  bool FindKeyInternal(const std::string &key, const T **value_ptr = NULL) {
+    typename MapType::iterator iter = map_.find(key);
+    if (iter != map_.end()) {  // Found in the map...
+      if (value_ptr == NULL) {  // called from HasKey
+        return true;  // this is all we have to do.
+      } else {
+        *value_ptr = &(iter->second->Value());
+        if (opts_.once) {  // value won't be needed again, so mark
+          // for deletion.
+          to_delete_iter_ = iter;  // pending delete.
+          KALDI_ASSERT(!to_delete_iter_valid_);
+          to_delete_iter_valid_ = true;
+        }
+        return true;
+      }
+    }
+    while (state_ == kNoObject) {
+      ReadNextObject();
+      if (state_ == kHaveObject) {  // Successfully read object.
+        state_ = kNoObject;  // we are about to transfer ownership
+        // of the object in holder_ to map_.
+        // Insert it into map_.
+        std::pair<typename MapType::iterator, bool> pr =
+            map_.insert(typename MapType::value_type(cur_key_, holder_));
+
+        if (!pr.second) {  // Was not inserted-- previous element w/ same key
+          delete holder_;  // map was not changed, no ownership transferred.
+          holder_ = NULL;
+          KALDI_ERR << "Error in RandomAccessTableReader: duplicate key "
+                    << cur_key_ << " in archive " << archive_rxfilename_;
+        }
+        holder_ = NULL;  // ownership transferred to map_.
+        if (cur_key_ == key) {  // the one we wanted..
+          if (value_ptr == NULL) {  // called from HasKey
+            return true;
+          } else {  // called from Value()
+            *value_ptr = &(pr.first->second->Value());  // this gives us the
+            // Value() from the Holder in the map.
+            if (opts_.once) {  // mark for deletion, as won't be needed again.
+              to_delete_iter_ = pr.first;
+              KALDI_ASSERT(!to_delete_iter_valid_);
+              to_delete_iter_valid_ = true;
+            }
+            return true;
+          }
+        }
+      }
+    }
+    if (opts_.once && key == first_deleted_string_) {
+      KALDI_ERR << "You specified the once (o) option but "
+                << "you are calling using key " << key
+                << " more than once: rspecifier is " << rspecifier_;
+    }
+    return false;  // We read the entire archive (or got to error state) and
+    // didn't find it.
+ } + + typedef unordered_map MapType; + MapType map_; + + typename MapType::iterator to_delete_iter_; + bool to_delete_iter_valid_; + + std::string first_deleted_string_; // keep the first string we deleted + // from map_ (if opts_.once == true). It's for an inexact spot-check that the + // "once" option isn't being used incorrectly. +}; + + + + + +template +RandomAccessTableReader::RandomAccessTableReader(const + std::string &rspecifier): + impl_(NULL) { + if (rspecifier != "" && !Open(rspecifier)) + KALDI_ERR << "Error opening RandomAccessTableReader object " + " (rspecifier is: " << rspecifier << ")"; +} + +template +bool RandomAccessTableReader::Open(const std::string &rspecifier) { + if (IsOpen()) + KALDI_ERR << "Already open."; + RspecifierOptions opts; + RspecifierType rs = ClassifyRspecifier(rspecifier, NULL, &opts); + switch (rs) { + case kScriptRspecifier: + impl_ = new RandomAccessTableReaderScriptImpl(); + break; + case kArchiveRspecifier: + if (opts.sorted) { + if (opts.called_sorted) // "doubly" sorted case. + impl_ = new RandomAccessTableReaderDSortedArchiveImpl(); + else + impl_ = new RandomAccessTableReaderSortedArchiveImpl(); + } else { + impl_ = new RandomAccessTableReaderUnsortedArchiveImpl(); + } + break; + case kNoRspecifier: default: + KALDI_WARN << "Invalid rspecifier: " + << rspecifier; + return false; + } + if (!impl_->Open(rspecifier)) { + // A warning will already have been printed. 
+ delete impl_; + impl_ = NULL; + return false; + } + return true; +} + +template +bool RandomAccessTableReader::HasKey(const std::string &key) { + CheckImpl(); + if (!IsToken(key)) + KALDI_ERR << "Invalid key \"" << key << '"'; + return impl_->HasKey(key); +} + + +template +const typename RandomAccessTableReader::T& +RandomAccessTableReader::Value(const std::string &key) { + CheckImpl(); + return impl_->Value(key); +} + +template +bool RandomAccessTableReader::Close() { + CheckImpl(); + bool ans =impl_->Close(); + delete impl_; + impl_ = NULL; + return ans; +} + +template +RandomAccessTableReader::~RandomAccessTableReader() { + if (IsOpen() && !Close()) // call Close() yourself to stop this being thrown. + KALDI_ERR << "failure detected in destructor."; +} + +template +void SequentialTableReader::CheckImpl() const { + if (!impl_) { + KALDI_ERR << "Trying to use empty SequentialTableReader (perhaps you " + << "passed the empty string as an argument to a program?)"; + } +} + +template +void RandomAccessTableReader::CheckImpl() const { + if (!impl_) { + KALDI_ERR << "Trying to use empty RandomAccessTableReader (perhaps you " + << "passed the empty string as an argument to a program?)"; + } +} + +template +void TableWriter::CheckImpl() const { + if (!impl_) { + KALDI_ERR << "Trying to use empty TableWriter (perhaps you " + << "passed the empty string as an argument to a program?)"; + } +} + +template +RandomAccessTableReaderMapped::RandomAccessTableReaderMapped( + const std::string &table_rxfilename, + const std::string &utt2spk_rxfilename): + reader_(table_rxfilename), token_reader_(table_rxfilename.empty() ? 
"" : + utt2spk_rxfilename), + utt2spk_rxfilename_(utt2spk_rxfilename) { } + +template +bool RandomAccessTableReaderMapped::Open( + const std::string &table_rxfilename, + const std::string &utt2spk_rxfilename) { + if (reader_.IsOpen()) reader_.Close(); + if (token_reader_.IsOpen()) token_reader_.Close(); + KALDI_ASSERT(!table_rxfilename.empty()); + if (!reader_.Open(table_rxfilename)) return false; // will have printed + // warning internally, probably. + if (!utt2spk_rxfilename.empty()) { + if (!token_reader_.Open(utt2spk_rxfilename)) { + reader_.Close(); + return false; + } + } + return true; +} + + +template +bool RandomAccessTableReaderMapped::HasKey(const std::string &utt) { + // We don't check IsOpen, we let the call go through to the member variable + // (reader_), which will crash with a more informative error message than + // we can give here, as we don't any longer know the rxfilename. + if (token_reader_.IsOpen()) { // We need to map the key from utt to spk. + if (!token_reader_.HasKey(utt)) + KALDI_ERR << "Attempting to read key " << utt << ", which is not present " + << "in utt2spk map or similar map being read from " + << PrintableRxfilename(utt2spk_rxfilename_); + const std::string &spk = token_reader_.Value(utt); + return reader_.HasKey(spk); + } else { + return reader_.HasKey(utt); + } +} + +template +const typename Holder::T& RandomAccessTableReaderMapped::Value( + const std::string &utt) { + if (token_reader_.IsOpen()) { // We need to map the key from utt to spk. 
+ if (!token_reader_.HasKey(utt)) + KALDI_ERR << "Attempting to read key " << utt << ", which is not present " + << "in utt2spk map or similar map being read from " + << PrintableRxfilename(utt2spk_rxfilename_); + const std::string &spk = token_reader_.Value(utt); + return reader_.Value(spk); + } else { + return reader_.Value(utt); + } +} + + + +/// @} + +} // end namespace kaldi + + + +#endif // KALDI_UTIL_KALDI_TABLE_INL_H_ diff --git a/speechx/speechx/kaldi/util/kaldi-table.cc b/speechx/speechx/kaldi/util/kaldi-table.cc new file mode 100644 index 00000000..1aeceb2b --- /dev/null +++ b/speechx/speechx/kaldi/util/kaldi-table.cc @@ -0,0 +1,321 @@ +// util/kaldi-table.cc + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#include "util/kaldi-table.h" +#include "util/text-utils.h" + +namespace kaldi { + + +bool ReadScriptFile(const std::string &rxfilename, + bool warn, + std::vector > + *script_out) { + bool is_binary; + Input input; + + if (!input.Open(rxfilename, &is_binary)) { + if (warn) KALDI_WARN << "Error opening script file: " << + PrintableRxfilename(rxfilename); + return false; + } + if (is_binary) { + if (warn) KALDI_WARN << "Error: script file appears to be binary: " << + PrintableRxfilename(rxfilename); + return false; + } + + bool ans = ReadScriptFile(input.Stream(), warn, script_out); + if (warn && !ans) + KALDI_WARN << "[script file was: " << PrintableRxfilename(rxfilename) << + "]"; + return ans; +} + +bool ReadScriptFile(std::istream &is, + bool warn, + std::vector > + *script_out) { + KALDI_ASSERT(script_out != NULL); + std::string line; + int line_number = 0; + while (getline(is, line)) { + line_number++; + const char *c = line.c_str(); + if (*c == '\0') { + if (warn) + KALDI_WARN << "Empty " << line_number << "'th line in script file"; + return false; // Empty line so invalid scp file format.. 
+ } + + std::string key, rest; + SplitStringOnFirstSpace(line, &key, &rest); + + if (key.empty() || rest.empty()) { + if (warn) + KALDI_WARN << "Invalid " << line_number << "'th line in script file" + <<":\"" << line << '"'; + return false; + } + script_out->resize(script_out->size()+1); + script_out->back().first = key; + script_out->back().second = rest; + } + return true; +} + +bool WriteScriptFile(std::ostream &os, + const std::vector > + &script) { + if (!os.good()) { + KALDI_WARN << "WriteScriptFile: attempting to write to invalid stream."; + return false; + } + std::vector >::const_iterator iter; + for (iter = script.begin(); iter != script.end(); ++iter) { + if (!IsToken(iter->first)) { + KALDI_WARN << "WriteScriptFile: using invalid token \"" << iter->first << + '"'; + return false; + } + if (iter->second.find('\n') != std::string::npos || + (iter->second.length() != 0 && + (isspace(iter->second[0]) || + isspace(iter->second[iter->second.length()-1])))) { + // second part contains newline or leading or trailing space. + KALDI_WARN << "WriteScriptFile: attempting to write invalid line \"" << + iter->second << '"'; + return false; + } + os << iter->first << ' ' << iter->second << '\n'; + } + if (!os.good()) { + KALDI_WARN << "WriteScriptFile: stream in error state."; + return false; + } + return true; +} + +bool WriteScriptFile(const std::string &wxfilename, + const std::vector > + &script) { + Output output; + if (!output.Open(wxfilename, false, false)) { // false, false means not + // binary, no binary-mode header. 
+ KALDI_ERR << "Error opening output stream for script file: " + << PrintableWxfilename(wxfilename); + return false; + } + if (!WriteScriptFile(output.Stream(), script)) { + KALDI_ERR << "Error writing script file to stream " + << PrintableWxfilename(wxfilename); + return false; + } + return true; +} + + + +WspecifierType ClassifyWspecifier(const std::string &wspecifier, + std::string *archive_wxfilename, + std::string *script_wxfilename, + WspecifierOptions *opts) { + // Examples: + // ark,t:wxfilename -> kArchiveWspecifier + // ark,b:wxfilename -> kArchiveWspecifier + // scp,t:rxfilename -> kScriptWspecifier + // scp,t:rxfilename -> kScriptWspecifier + // ark,scp,t:filename, wxfilename -> kBothWspecifier + // ark,scp:filename, wxfilename -> kBothWspecifier + // Note we can include the flush option (f) or no-flush (nf) + // anywhere: e.g. + // ark,scp,f:filename, wxfilename -> kBothWspecifier + // or: + // scp,t,nf:rxfilename -> kScriptWspecifier + + if (archive_wxfilename) archive_wxfilename->clear(); + if (script_wxfilename) script_wxfilename->clear(); + + size_t pos = wspecifier.find(':'); + if (pos == std::string::npos) return kNoWspecifier; + if (isspace(*(wspecifier.rbegin()))) return kNoWspecifier; // Trailing space + // disallowed. + + std::string before_colon(wspecifier, 0, pos), after_colon(wspecifier, pos+1); + + std::vector split_first_part; // Split part before ':' on ', '. + SplitStringToVector(before_colon, ", ", false, &split_first_part); // false== + // don't omit empty strings between commas. + + WspecifierType ws = kNoWspecifier; + + if (opts != NULL) + *opts = WspecifierOptions(); // Make sure all the defaults are as in the + // default constructor of the options class. + + for (size_t i = 0; i < split_first_part.size(); i++) { + const std::string &str = split_first_part[i]; // e.g. "b", "t", "f", "ark", + // "scp". 
+ const char *c = str.c_str(); + if (!strcmp(c, "b")) { + if (opts) opts->binary = true; + } else if (!strcmp(c, "f")) { + if (opts) opts->flush = true; + } else if (!strcmp(c, "nf")) { + if (opts) opts->flush = false; + } else if (!strcmp(c, "t")) { + if (opts) opts->binary = false; + } else if (!strcmp(c, "p")) { + if (opts) opts->permissive = true; + } else if (!strcmp(c, "ark")) { + if (ws == kNoWspecifier) ws = kArchiveWspecifier; + else + return kNoWspecifier; // We do not allow "scp, ark", only "ark, + // scp". + } else if (!strcmp(c, "scp")) { + if (ws == kNoWspecifier) ws = kScriptWspecifier; + else if (ws == kArchiveWspecifier) ws = kBothWspecifier; + else + return kNoWspecifier; // repeated "scp" option: invalid. + } else { + return kNoWspecifier; // Could not interpret this option. + } + } + + switch (ws) { + case kArchiveWspecifier: + if (archive_wxfilename) + *archive_wxfilename = after_colon; + break; + case kScriptWspecifier: + if (script_wxfilename) + *script_wxfilename = after_colon; + break; + case kBothWspecifier: + pos = after_colon.find(','); // first comma. + if (pos == std::string::npos) return kNoWspecifier; + if (archive_wxfilename) + *archive_wxfilename = std::string(after_colon, 0, pos); + if (script_wxfilename) + *script_wxfilename = std::string(after_colon, pos+1); + break; + case kNoWspecifier: default: break; + } + return ws; +} + + + +RspecifierType ClassifyRspecifier(const std::string &rspecifier, + std::string *rxfilename, + RspecifierOptions *opts) { + // Examples + // ark:rxfilename -> kArchiveRspecifier + // scp:rxfilename -> kScriptRspecifier + // + // We also allow the meaningless prefixes b, and t, + // plus the options o (once), no (not-once), + // s (sorted) and ns (not-sorted), p (permissive) + // and np (not-permissive). 
+ // so the following would be valid: + // + // f, o, b, np, ark:rxfilename -> kArchiveRspecifier + // + // Examples: + // + // b, ark:rxfilename -> kArchiveRspecifier + // t, ark:rxfilename -> kArchiveRspecifier + // b, scp:rxfilename -> kScriptRspecifier + // t, no, s, scp:rxfilename -> kScriptRspecifier + // t, ns, scp:rxfilename -> kScriptRspecifier + + // Improperly formed Rspecifiers will be classified as kNoRspecifier. + + if (rxfilename) rxfilename->clear(); + + if (opts != NULL) + *opts = RspecifierOptions(); // Make sure all the defaults are as in the + // default constructor of the options class. + + size_t pos = rspecifier.find(':'); + if (pos == std::string::npos) return kNoRspecifier; + + if (isspace(*(rspecifier.rbegin()))) return kNoRspecifier; // Trailing space + // disallowed. + + std::string before_colon(rspecifier, 0, pos), + after_colon(rspecifier, pos+1); + + std::vector split_first_part; // Split part before ':' on ', '. + SplitStringToVector(before_colon, ", ", false, &split_first_part); // false== + // don't omit empty strings between commas. + + RspecifierType rs = kNoRspecifier; + + for (size_t i = 0; i < split_first_part.size(); i++) { + const std::string &str = split_first_part[i]; // e.g. "b", "t", "f", "ark", + // "scp". + const char *c = str.c_str(); + if (!strcmp(c, "b")); // Ignore this option. It's so we can use the same + // specifiers for rspecifiers and wspecifiers. + else if (!strcmp(c, "t")); // Ignore this option too. 
+ else if (!strcmp(c, "o")) { + if (opts) opts->once = true; + } else if (!strcmp(c, "no")) { + if (opts) opts->once = false; + } else if (!strcmp(c, "p")) { + if (opts) opts->permissive = true; + } else if (!strcmp(c, "np")) { + if (opts) opts->permissive = false; + } else if (!strcmp(c, "s")) { + if (opts) opts->sorted = true; + } else if (!strcmp(c, "ns")) { + if (opts) opts->sorted = false; + } else if (!strcmp(c, "cs")) { + if (opts) opts->called_sorted = true; + } else if (!strcmp(c, "ncs")) { + if (opts) opts->called_sorted = false; + } else if (!strcmp(c, "bg")) { + if (opts) opts->background = true; + } else if (!strcmp(c, "ark")) { + if (rs == kNoRspecifier) rs = kArchiveRspecifier; + else + return kNoRspecifier; // Repeated or combined ark and scp options + // invalid. + } else if (!strcmp(c, "scp")) { + if (rs == kNoRspecifier) rs = kScriptRspecifier; + else + return kNoRspecifier; // Repeated or combined ark and scp options + // invalid. + } else { + return kNoRspecifier; // Could not interpret this option. + } + } + if ((rs == kArchiveRspecifier || rs == kScriptRspecifier) + && rxfilename != NULL) + *rxfilename = after_colon; + return rs; +} + + + + + + +} // end namespace kaldi diff --git a/speechx/speechx/kaldi/util/kaldi-table.h b/speechx/speechx/kaldi/util/kaldi-table.h new file mode 100644 index 00000000..6865cea1 --- /dev/null +++ b/speechx/speechx/kaldi/util/kaldi-table.h @@ -0,0 +1,471 @@ +// util/kaldi-table.h + +// Copyright 2009-2011 Microsoft Corporation +// 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_UTIL_KALDI_TABLE_H_ +#define KALDI_UTIL_KALDI_TABLE_H_ + +#include +#include +#include + +#include "base/kaldi-common.h" +#include "util/kaldi-holder.h" + +namespace kaldi { + +// Forward declarations +template class RandomAccessTableReaderImplBase; +template class SequentialTableReaderImplBase; +template class TableWriterImplBase; + +/// \addtogroup table_group +/// @{ + +// This header defines the Table classes (RandomAccessTableReader, +// SequentialTableReader and TableWriter) and explains what the Holder classes, +// which the Table class requires as a template argument, are like. It also +// explains the "rspecifier" and "wspecifier" concepts (these are strings that +// explain how to read/write objects via archives or scp files. A table is +// conceptually a collection of objects of a particular type T indexed by keys +// of type std::string (these Keys additionally have an order within +// each table). +// The Table classes are templated on a type (call it Holder) such that +// Holder::T is a typedef equal to T. + +// see kaldi-holder.h for detail on the Holder classes. + +typedef std::vector KeyList; + +// Documentation for "wspecifier" +// "wspecifier" describes how we write a set of objects indexed by keys. +// The basic, unadorned wspecifiers are as follows: +// +// ark:wxfilename +// scp:rxfilename +// ark,scp:filename,wxfilename +// ark,scp:filename,wxfilename +// +// +// We also allow the following modifiers: +// t means text mode. +// b means binary mode. 
+// f means flush the stream after writing each entry. +// (nf means don't flush, and the default is not to flush). +// p means permissive mode, when writing to an "scp" file only: will ignore +// missing scp entries, i.e. won't write anything for those files but will +// return success status). +// +// So the following are valid wspecifiers: +// ark,b,f:foo +// "ark,b,b:| gzip -c > foo" +// "ark,scp,t,nf:foo.ark,|gzip -c > foo.scp.gz" +// ark,b:- +// +// The meanings of rxfilename and wxfilename are as described in +// kaldi-io.h (they are filenames but include pipes, stdin/stdout +// and so on; filename is a regular filename. +// + +// The ark:wxfilename type of wspecifier instructs the class to +// write directly to an archive. For small objects (e.g. lists of ints), +// the text archive format will generally be human readable with one line +// per entry in the archive. +// +// The type "scp:xfilename" refers to an scp file which should +// already exist on disk, and tells us where to write the data for +// each key (usually an actual file); each line of the scp file +// would be: +// key xfilename +// +// The type ark,scp:filename,wxfilename means +// we write both an archive and an scp file that specifies offsets into the +// archive, with lines like: +// key filename:12407 +// where the number is the byte offset into the file. +// In this case we restrict the archive-filename to be an actual filename, +// as we can't see a situation where an extended filename would make sense +// for this (we can't fseek() in pipes). + +enum WspecifierType { + kNoWspecifier, + kArchiveWspecifier, + kScriptWspecifier, + kBothWspecifier +}; + +struct WspecifierOptions { + bool binary; + bool flush; + bool permissive; // will ignore absent scp entries. 
+ WspecifierOptions(): binary(true), flush(false), permissive(false) { } +}; + +// ClassifyWspecifier returns the type of the wspecifier string, +// and (if pointers are non-NULL) outputs the extra information +// about the options, and the script and archive +// filenames. +WspecifierType ClassifyWspecifier(const std::string &wspecifier, + std::string *archive_wxfilename, + std::string *script_wxfilename, + WspecifierOptions *opts); + +// ReadScriptFile reads an .scp file in its entirety, and appends it +// (in order as it was in the scp file) in script_out_, which contains +// pairs of (key, xfilename). The .scp +// file format is: on each line, key xfilename +// where xfilename means rxfilename or wxfilename, and may contain internal +// spaces (we trim away any leading or trailing space). The key is space-free. +// ReadScriptFile returns true if the format was valid (empty files +// are valid). +// If 'print_warnings', it will print out warning messages that explain what +// kind of error there was. +bool ReadScriptFile(const std::string &rxfilename, + bool print_warnings, + std::vector > + *script_out); + +// This version of ReadScriptFile works from an istream. +bool ReadScriptFile(std::istream &is, + bool print_warnings, + std::vector > + *script_out); + +// Writes, for each entry in script, the first element, then ' ', then the +// second element then '\n'. Checks that the keys (first elements of pairs) are +// valid tokens (nonempty, no whitespace), and the values (second elements of +// pairs) are newline-free and contain no leading or trailing space. Returns +// true on success. +bool WriteScriptFile(const std::string &wxfilename, + const std::vector > + &script); + +// This version writes to an ostream. +bool WriteScriptFile(std::ostream &os, + const std::vector > + &script); + +// Documentation for "rspecifier" +// "rspecifier" describes how we read a set of objects indexed by keys. 
+// The possibilities are: +// +// ark:rxfilename +// scp:rxfilename +// +// We also allow various modifiers: +// o means the program will only ask for each key once, which enables +// the reader to discard already-asked-for values. +// s means the keys are sorted on input (means we don't have to read till +// eof if someone asked for a key that wasn't there). +// cs means that it is called in sorted order (we are generally asserting +// this based on knowledge of how the program works). +// p means "permissive", and causes it to skip over keys whose corresponding +// scp-file entries cannot be read. [and to ignore errors in archives and +// script files, and just consider the "good" entries]. +// We allow the negation of the options above, as in no, ns, np, +// but these aren't currently very useful (just equivalent to omitting the +// corresponding option). +// [any of the above options can be prefixed by n to negate them, e.g. no, +// ns, ncs, np; but these aren't currently useful as you could just omit +// the option]. +// bg means "background". It currently has no effect for random-access readers, +// but for sequential readers it will cause it to "read ahead" to the next +// value, in a background thread. Recommended when reading larger objects +// such as neural-net training examples, especially when you want to +// maximize GPU usage. +// +// b is ignored [for scripting convenience] +// t is ignored [for scripting convenience] +// +// +// So for instance the following would be a valid rspecifier: +// +// "o, s, p, ark:gunzip -c foo.gz|" + +struct RspecifierOptions { + // These options only make a difference for the RandomAccessTableReader class. + bool once; // we assert that the program will only ask for each key once. + bool sorted; // we assert that the keys are sorted. + bool called_sorted; // we assert that the (HasKey(), Value() functions will + // also be called in sorted order. [this implies "once" but not vice versa]. 
+ bool permissive; // If "permissive", when reading from scp files it treats + // scp files that can't be read as if the corresponding key were not there. + // For archive files it will suppress errors getting thrown if the archive + // is corrupted and can't be read to the end. + bool background; // For sequential readers, if the background option ("bg") + // is provided, it will read ahead to the next object in a + // background thread. + RspecifierOptions(): once(false), sorted(false), + called_sorted(false), permissive(false), + background(false) { } +}; + +enum RspecifierType { + kNoRspecifier, + kArchiveRspecifier, + kScriptRspecifier +}; + +RspecifierType ClassifyRspecifier(const std::string &rspecifier, + std::string *rxfilename, + RspecifierOptions *opts); + + +/// Allows random access to a collection +/// of objects in an archive or script file; see \ref io_sec_tables. +template +class RandomAccessTableReader { + public: + typedef typename Holder::T T; + + RandomAccessTableReader(): impl_(NULL) { } + + // This constructor is equivalent to default constructor + "open", but + // throws on error. + explicit RandomAccessTableReader(const std::string &rspecifier); + + // Opens the table. + bool Open(const std::string &rspecifier); + + // Returns true if table is open. + bool IsOpen() const { return (impl_ != NULL); } + + // Close() will close the table [throws if it was not open], + // and returns true on success (false if we were reading an + // archive and we discovered an error in the archive). + bool Close(); + + // Says if it has this key. + // If you are using the "permissive" (p) read option, + // it will return false for keys whose corresponding entry + // in the scp file cannot be read. + + bool HasKey(const std::string &key); + + // Value() may throw if you are reading an scp file, you + // do not have the "permissive" (p) option, and an entry + // in the scp file cannot be read. Typically you won't + // want to catch this error. 
+ const T &Value(const std::string &key); + + ~RandomAccessTableReader(); + + // Allow copy-constructor only for non-opened readers (needed for inclusion in + // stl vector) + RandomAccessTableReader(const RandomAccessTableReader + &other): + impl_(NULL) { KALDI_ASSERT(other.impl_ == NULL); } + private: + // Disallow assignment. + RandomAccessTableReader &operator=(const RandomAccessTableReader&); + void CheckImpl() const; // Checks that impl_ is non-NULL; prints an error + // message and dies (with KALDI_ERR) if NULL. + RandomAccessTableReaderImplBase *impl_; +}; + + + +/// A templated class for reading objects sequentially from an archive or script +/// file; see \ref io_sec_tables. +template +class SequentialTableReader { + public: + typedef typename Holder::T T; + + SequentialTableReader(): impl_(NULL) { } + + // This constructor equivalent to default constructor + "open", but + // throws on error. + explicit SequentialTableReader(const std::string &rspecifier); + + // Opens the table. Returns exit status; but does throw if previously open + // stream was in error state. You can call Close to prevent this; anyway, + // calling Open more than once is not usually needed. + bool Open(const std::string &rspecifier); + + // Returns true if we're done. It will also return true if there's some kind + // of error and we can't read any more; in this case, you can detect the + // error by calling Close and checking the return status; otherwise + // the destructor will throw. + inline bool Done(); + + // Only valid to call Key() if Done() returned false. + inline std::string Key(); + + // FreeCurrent() is provided as an optimization to save memory, for large + // objects. It instructs the class to deallocate the current value. The + // reference Value() will be invalidated by this. + void FreeCurrent(); + + // Return reference to the current value. It's only valid to call this if + // Done() returned false. The reference is valid till next call to this + // object. 
It will throw if you are reading an scp file, did not specify the + // "permissive" (p) option and the file cannot be read. [The permissive + // option makes it behave as if that key does not even exist, if the + // corresponding file cannot be read.] You probably wouldn't want to catch + // this exception; the user can just specify the p option in the rspecifier. + // We make this non-const to enable things like shallow swap on the held + // object in situations where this would avoid making a redundant copy. + T &Value(); + + // Next goes to the next key. It will not throw; any error will + // result in Done() returning true, and then the destructor will + // throw unless you call Close(). + void Next(); + + // Returns true if table is open for reading (does not imply + // stream is in good state). + bool IsOpen() const; + + // Close() will return false (failure) if Done() became true + // because of an error/ condition rather than because we are + // really done [e.g. because of an error or early termination + // in the archive]. + // If there is an error and you don't call Close(), the destructor + // will fail. + // Close() + bool Close(); + + // The destructor may throw. This is the desired behaviour, as it's the way + // we signal the error to the user (to detect it, call Close(). The issue is + // that otherwise the user has no way to tell whether Done() returned true + // because we reached the end of the archive or script, or because there was + // an error that prevented further reading. + ~SequentialTableReader(); + + // Allow copy-constructor only for non-opened readers (needed for inclusion in + // stl vector) + SequentialTableReader(const SequentialTableReader &other): + impl_(NULL) { KALDI_ASSERT(other.impl_ == NULL); } + private: + // Disallow assignment. + SequentialTableReader &operator = (const SequentialTableReader&); + void CheckImpl() const; // Checks that impl_ is non-NULL; prints an error + // message and dies (with KALDI_ERR) if NULL. 
+ SequentialTableReaderImplBase *impl_; +}; + + +/// A templated class for writing objects to an +/// archive or script file; see \ref io_sec_tables. +template +class TableWriter { + public: + typedef typename Holder::T T; + + TableWriter(): impl_(NULL) { } + + // This constructor equivalent to default constructor + // + "open", but throws on error. See docs for + // wspecifier above. + explicit TableWriter(const std::string &wspecifier); + + // Opens the table. See docs for wspecifier above. + // If it returns true, it is open. + bool Open(const std::string &wspecifier); + + // Returns true if open for writing. + bool IsOpen() const; + + // Write the object. Throws KaldiFatalError on error via the KALDI_ERR macro. + inline void Write(const std::string &key, const T &value) const; + + + // Flush will flush any archive; it does not return error status + // or throw, any errors will be reported on the next Write or Close. + // Useful if we may be writing to a command in a pipe and want + // to ensure good CPU utilization. + void Flush(); + + // Close() is not necessary to call, as the destructor + // closes it; it's mainly useful if you want to handle + // error states because the destructor will throw on + // error if you do not call Close(). + bool Close(); + + ~TableWriter(); + + // Allow copy-constructor only for non-opened writers (needed for inclusion in + // stl vector) + TableWriter(const TableWriter &other): impl_(NULL) { + KALDI_ASSERT(other.impl_ == NULL); + } + private: + TableWriter &operator = (const TableWriter&); // Disallow assignment. + + void CheckImpl() const; // Checks that impl_ is non-NULL; prints an error + // message and dies (with KALDI_ERR) if NULL. + TableWriterImplBase *impl_; +}; + + +/// This class is for when you are reading something in random access, but +/// it may actually be stored per-speaker (or something similar) but the +/// keys you're using are per utterance. 
So you also provide an "rxfilename" +/// for a file containing lines like +/// utt1 spk1 +/// utt2 spk1 +/// utt3 spk1 +/// and so on. Note: this is optional; if it is an empty string, we just won't +/// do the mapping. Also, "table_rxfilename" may be the empty string (as for +/// a regular table), in which case the table just won't be opened. +/// We provide only the most frequently used of the functions of +/// RandomAccessTableReader. + +template +class RandomAccessTableReaderMapped { + public: + typedef typename Holder::T T; + /// Note: "utt2spk_rxfilename" will in the normal case be an rxfilename + /// for an utterance to speaker map, but this code is general; it accepts + /// a generic map. + RandomAccessTableReaderMapped(const std::string &table_rxfilename, + const std::string &utt2spk_rxfilename); + + RandomAccessTableReaderMapped() {} + + /// Note: when calling Open, utt2spk_rxfilename may be empty. + bool Open(const std::string &table_rxfilename, + const std::string &utt2spk_rxfilename); + + bool HasKey(const std::string &key); + const T &Value(const std::string &key); + inline bool IsOpen() const { return reader_.IsOpen(); } + inline bool Close() { return reader_.Close(); } + + + + // The default copy-constructor will do what we want: it will crash for + // already-opened readers, by calling the member-variable copy-constructors. + private: + // Disallow assignment. + RandomAccessTableReaderMapped &operator = + (const RandomAccessTableReaderMapped&); + RandomAccessTableReader reader_; + RandomAccessTableReader token_reader_; + std::string utt2spk_rxfilename_; // Used only in diagnostic messages. 
+}; + + +/// @} end "addtogroup table_group" +} // end namespace kaldi + +#include "util/kaldi-table-inl.h" + +#endif // KALDI_UTIL_KALDI_TABLE_H_ diff --git a/speechx/speechx/kaldi/util/kaldi-thread.cc b/speechx/speechx/kaldi/util/kaldi-thread.cc new file mode 100644 index 00000000..2405d01f --- /dev/null +++ b/speechx/speechx/kaldi/util/kaldi-thread.cc @@ -0,0 +1,33 @@ +// util/kaldi-thread.cc + +// Copyright 2012 Johns Hopkins University (Author: Daniel Povey) +// Frantisek Skala + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-common.h" +#include "util/kaldi-thread.h" + +namespace kaldi { +int32 g_num_threads = 8; // Initialize this global variable. 
+ +MultiThreadable::~MultiThreadable() { + // default implementation does nothing +} + + + +} // end namespace kaldi diff --git a/speechx/speechx/kaldi/util/kaldi-thread.h b/speechx/speechx/kaldi/util/kaldi-thread.h new file mode 100644 index 00000000..50bf7dac --- /dev/null +++ b/speechx/speechx/kaldi/util/kaldi-thread.h @@ -0,0 +1,284 @@ +// util/kaldi-thread.h + +// Copyright 2012 Johns Hopkins University (Author: Daniel Povey) +// Frantisek Skala +// 2017 University of Southern California (Author: Dogan Can) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_THREAD_KALDI_THREAD_H_ +#define KALDI_THREAD_KALDI_THREAD_H_ 1 + +#include +#include "util/options-itf.h" +#include "util/kaldi-semaphore.h" + +// This header provides convenient mechanisms for parallelization. +// +// The class MultiThreader, and the function RunMultiThreaded provide a +// mechanism to run a specified number of jobs in parellel and wait for them +// all to finish. They accept objects of some class C that derives from the +// base class MultiThreadable. C needs to define the operator () that takes +// no arguments. See ExampleClass below. +// +// The class TaskSequencer addresses a different problem typically encountered +// in Kaldi command-line programs that process a sequence of items. The items +// to be processed are coming in. 
They are all of different sizes, e.g. +// utterances with different numbers of frames. We would like them to be +// processed in parallel to make good use of the threads available but they +// must be output in the same order they came in. Here, we again accept objects +// of some class C with an operator () that takes no arguments. C may also have +// a destructor with side effects (typically some kind of output). +// TaskSequencer is responsible for running the jobs in parallel. It has a +// function Run() that will accept a new object of class C; this will block +// until a thread is free, at which time it will spawn a thread that starts +// running the operator () of the object. When threads are finished running, +// the objects will be deleted. TaskSequencer guarantees that the destructors +// will be called sequentially (not in parallel) and in the same order the +// objects were given to the Run() function, so that it is safe for the +// destructor to have side effects such as outputting data. +// Note: the destructor of TaskSequencer will wait for any remaining jobs that +// are still running and will call the destructors. + + +namespace kaldi { + +extern int32 g_num_threads; // Maximum number of threads (for programs that +// use threads, which is not many of them, e.g. the SGMM update program does. +// This is 8 by default. You can change this on the command line, where +// used, with --num-threads. 
Programs that think they will use threads +// should register it with their ParseOptions, as something like: +// po.Register("num-threads", &g_num_threads, "Number of threads to use."); + +class MultiThreadable { + // To create a function object that does part of the job, inherit from this + // class, implement a copy constructor calling the default copy constructor + // of this base class (so that thread_id_ and num_threads_ are copied to new + // instances), and finally implement the operator() that does part of the job + // based on thread_id_ and num_threads_ variables. + // Note: example implementations are in util/kaldi-thread-test.cc + public: + virtual void operator() () = 0; + // Does the main function of the class + // Subclasses have to redefine this + virtual ~MultiThreadable(); + // Optional destructor. Note: the destructor of the object passed by the user + // will also be called, so watch out. + + public: + // Do not redeclare thread_id_ and num_threads_ in derived classes. + int32 thread_id_; // 0 <= thread_id_ < num_threads_ + int32 num_threads_; + + private: + // Have additional member variables as needed. +}; + + +class ExampleClass: public MultiThreadable { + public: + ExampleClass(int32 *foo); // Typically there will be an initializer that + // takes arguments. + + ExampleClass(const ExampleClass &other); // A copy constructor is also needed; + // some example classes use the default version of this. + + void operator() () { + // Does the main function of the class. This + // function will typically want to look at the values of the + // member variables thread_id_ and num_threads_, inherited + // from MultiThreadable. + } + ~ExampleClass() { + // Optional destructor. Sometimes useful things happen here, + // for example summing up of certain quantities. See code + // that uses RunMultiThreaded for examples. + } + private: + // Have additional member variables as needed. 
+}; + + +template +class MultiThreader { + public: + MultiThreader(int32 num_threads, const C &c_in) : + threads_(std::max(1, num_threads)), + cvec_(std::max(1, num_threads), c_in) { + if (num_threads == 0) { + // This is a special case with num_threads == 0, which behaves like with + // num_threads == 1 but without creating extra threads. This can be + // useful in GPU computations where threads cannot be used. + cvec_[0].thread_id_ = 0; + cvec_[0].num_threads_ = 1; + (cvec_[0])(); + } else { + for (int32 i = 0; i < threads_.size(); i++) { + cvec_[i].thread_id_ = i; + cvec_[i].num_threads_ = threads_.size(); + threads_[i] = std::thread(std::ref(cvec_[i])); + } + } + } + ~MultiThreader() { + for (size_t i = 0; i < threads_.size(); i++) + if (threads_[i].joinable()) + threads_[i].join(); + } + private: + std::vector threads_; + std::vector cvec_; +}; + +/// Here, class C should inherit from MultiThreadable. Note: if you want to +/// control the number of threads yourself, or need to do something in the main +/// thread of the program while the objects exist, just initialize the +/// MultiThreader object yourself. +template void RunMultiThreaded(const C &c_in) { + MultiThreader m(g_num_threads, c_in); +} + + +struct TaskSequencerConfig { + int32 num_threads; + int32 num_threads_total; + TaskSequencerConfig(): num_threads(1), num_threads_total(0) { } + void Register(OptionsItf *opts) { + opts->Register("num-threads", &num_threads, "Number of actively processing " + "threads to run in parallel"); + opts->Register("num-threads-total", &num_threads_total, "Total number of " + "threads, including those that are waiting on other threads " + "to produce their output. Controls memory use. If <= 0, " + "defaults to --num-threads plus 20. 
Otherwise, must " + "be >= num-threads."); + } +}; + +// C should have an operator () taking no arguments, that does some kind +// of computation, and a destructor that produces some kind of output (the +// destructors will be run sequentially in the same order Run as called. +template +class TaskSequencer { + public: + TaskSequencer(const TaskSequencerConfig &config): + num_threads_(config.num_threads), + threads_avail_(config.num_threads), + tot_threads_avail_(config.num_threads_total > 0 ? config.num_threads_total : + config.num_threads + 20), + thread_list_(NULL) { + KALDI_ASSERT((config.num_threads_total <= 0 || + config.num_threads_total >= config.num_threads) && + "num-threads-total, if specified, must be >= num-threads"); + } + + /// This function takes ownership of the pointer "c", and will delete it + /// in the same sequence as Run was called on the jobs. + void Run(C *c) { + // run in main thread + if (num_threads_ == 0) { + (*c)(); + delete c; + return; + } + + threads_avail_.Wait(); // wait till we have a thread for computation free. + tot_threads_avail_.Wait(); // this ensures we don't have too many threads + // waiting on I/O, and consume too much memory. + + // put the new RunTaskArgsList object at head of the singly + // linked list thread_list_. + thread_list_ = new RunTaskArgsList(this, c, thread_list_); + thread_list_->thread = std::thread(TaskSequencer::RunTask, + thread_list_); + } + + void Wait() { // You call this at the end if it's more convenient + // than waiting for the destructor. It waits for all tasks to finish. + if (thread_list_ != NULL) { + thread_list_->thread.join(); + KALDI_ASSERT(thread_list_->tail == NULL); // thread would not + // have exited without setting tail to NULL. + delete thread_list_; + thread_list_ = NULL; + } + } + + /// The destructor waits for the last thread to exit. + ~TaskSequencer() { + Wait(); + } + private: + struct RunTaskArgsList { + TaskSequencer *me; // Think of this as a "this" pointer. 
+    C *c;  // the task object we are expected to run (owned; deleted in order)
+    std::thread thread;
+    RunTaskArgsList *tail;
+    RunTaskArgsList(TaskSequencer *me, C *c, RunTaskArgsList *tail):
+        me(me), c(c), tail(tail) {}
+  };
+  // This static function gets run in the threads that we create.
+  static void RunTask(RunTaskArgsList *args) {
+    // (1) run the job.
+    (*(args->c))();  // call operator () on args->c, which does the computation.
+    args->me->threads_avail_.Signal();  // Signal that the compute-intensive
+    // part of the thread is done (we want to run no more than
+    // config_.num_threads of these.)
+
+    // (2) we want to destroy the object "c" now, by deleting it.  But for
+    // correct sequencing (this is the whole point of this class, it
+    // is intended to ensure the output of the program is in correct order),
+    // we first wait till the previous thread, whose details will be in "tail",
+    // is finished.
+    if (args->tail != NULL) {
+      args->tail->thread.join();
+    }
+
+    delete args->c;  // delete the object "c".  This may cause some output,
+    // e.g. to a stream.  We don't need to worry about concurrent access to
+    // the output stream, because each thread waits for the previous thread
+    // to be done, before doing this.  So there is no risk of concurrent
+    // access.
+    args->c = NULL;
+
+    if (args->tail != NULL) {
+      KALDI_ASSERT(args->tail->tail == NULL);  // Because we already
+      // did join on args->tail->thread, which means that
+      // thread was done, and before it exited, it would have
+      // deleted and set to NULL its tail (which is the next line of code).
+      delete args->tail;
+      args->tail = NULL;
+    }
+    // At this point we are exiting from the thread.  Signal the
+    // "tot_threads_avail_" semaphore which is used to limit the total number of threads that are alive, including
+    // not only those that are in active computation in c->operator (),  but those
+    // that are waiting on I/O or other threads.
+ args->me->tot_threads_avail_.Signal(); + } + + int32 num_threads_; // copy of config.num_threads (since Semaphore doesn't store original count) + + Semaphore threads_avail_; // Initialized to the number of threads we are + // supposed to run with; the function Run() waits on this. + + Semaphore tot_threads_avail_; // We use this semaphore to ensure we don't + // consume too much memory... + RunTaskArgsList *thread_list_; + +}; + +} // namespace kaldi + +#endif // KALDI_THREAD_KALDI_THREAD_H_ diff --git a/speechx/speechx/kaldi/util/options-itf.h b/speechx/speechx/kaldi/util/options-itf.h new file mode 100644 index 00000000..204f46d6 --- /dev/null +++ b/speechx/speechx/kaldi/util/options-itf.h @@ -0,0 +1,49 @@ +// itf/options-itf.h + +// Copyright 2013 Tanel Alumae, Tallinn University of Technology + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+
+#ifndef KALDI_ITF_OPTIONS_ITF_H_
+#define KALDI_ITF_OPTIONS_ITF_H_ 1
+#include "base/kaldi-common.h"
+
+namespace kaldi {
+
+class OptionsItf {
+ public:
+
+  virtual void Register(const std::string &name,
+                        bool *ptr, const std::string &doc) = 0;
+  virtual void Register(const std::string &name,
+                        int32 *ptr, const std::string &doc) = 0;
+  virtual void Register(const std::string &name,
+                        uint32 *ptr, const std::string &doc) = 0;
+  virtual void Register(const std::string &name,
+                        float *ptr, const std::string &doc) = 0;
+  virtual void Register(const std::string &name,
+                        double *ptr, const std::string &doc) = 0;
+  virtual void Register(const std::string &name,
+                        std::string *ptr, const std::string &doc) = 0;
+
+  virtual ~OptionsItf() {}
+};
+
+}  // namespace kaldi
+
+#endif  // KALDI_ITF_OPTIONS_ITF_H_
+
+
diff --git a/speechx/speechx/kaldi/util/parse-options.cc b/speechx/speechx/kaldi/util/parse-options.cc
new file mode 100644
index 00000000..4b08ca39
--- /dev/null
+++ b/speechx/speechx/kaldi/util/parse-options.cc
@@ -0,0 +1,668 @@
+// util/parse-options.cc
+
+// Copyright 2009-2011  Karel Vesely;  Microsoft Corporation;
+//                      Saarland University (Author: Arnab Ghoshal);
+// Copyright 2012-2013  Johns Hopkins University (Author: Daniel Povey);
+//                      Frantisek Skala;  Arnab Ghoshal
+// Copyright 2013       Tanel Alumae
+//
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include +#include + +#include "util/parse-options.h" +#include "util/text-utils.h" +#include "base/kaldi-common.h" + +namespace kaldi { + + +ParseOptions::ParseOptions(const std::string &prefix, + OptionsItf *other): + print_args_(false), help_(false), usage_(""), argc_(0), argv_(NULL) { + ParseOptions *po = dynamic_cast(other); + if (po != NULL && po->other_parser_ != NULL) { + // we get here if this constructor is used twice, recursively. + other_parser_ = po->other_parser_; + } else { + other_parser_ = other; + } + if (po != NULL && po->prefix_ != "") { + prefix_ = po->prefix_ + std::string(".") + prefix; + } else { + prefix_ = prefix; + } +} + +void ParseOptions::Register(const std::string &name, + bool *ptr, const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +void ParseOptions::Register(const std::string &name, + int32 *ptr, const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +void ParseOptions::Register(const std::string &name, + uint32 *ptr, const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +void ParseOptions::Register(const std::string &name, + float *ptr, const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +void ParseOptions::Register(const std::string &name, + double *ptr, const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +void ParseOptions::Register(const std::string &name, + std::string *ptr, const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +// old-style, used for registering application-specific parameters +template +void ParseOptions::RegisterTmpl(const std::string &name, T *ptr, + const std::string &doc) { + if (other_parser_ == NULL) { + this->RegisterCommon(name, ptr, doc, false); + } else { + KALDI_ASSERT(prefix_ != "" && + "Cannot use empty prefix when registering with prefix."); + std::string new_name = 
prefix_ + '.' + name; // name becomes prefix.name + other_parser_->Register(new_name, ptr, doc); + } +} + +// does the common part of the job of registering a parameter +template +void ParseOptions::RegisterCommon(const std::string &name, T *ptr, + const std::string &doc, bool is_standard) { + KALDI_ASSERT(ptr != NULL); + std::string idx = name; + NormalizeArgName(&idx); + if (doc_map_.find(idx) != doc_map_.end()) + KALDI_WARN << "Registering option twice, ignoring second time: " << name; + this->RegisterSpecific(name, idx, ptr, doc, is_standard); +} + +// used to register standard parameters (those that are present in all of the +// applications) +template +void ParseOptions::RegisterStandard(const std::string &name, T *ptr, + const std::string &doc) { + this->RegisterCommon(name, ptr, doc, true); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, + bool *b, + const std::string &doc, + bool is_standard) { + bool_map_[idx] = b; + doc_map_[idx] = DocInfo(name, doc + " (bool, default = " + + ((*b)? 
"true)" : "false)"), is_standard); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, + int32 *i, + const std::string &doc, + bool is_standard) { + int_map_[idx] = i; + std::ostringstream ss; + ss << doc << " (int, default = " << *i << ")"; + doc_map_[idx] = DocInfo(name, ss.str(), is_standard); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, + uint32 *u, + const std::string &doc, + bool is_standard) { + uint_map_[idx] = u; + std::ostringstream ss; + ss << doc << " (uint, default = " << *u << ")"; + doc_map_[idx] = DocInfo(name, ss.str(), is_standard); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, + float *f, + const std::string &doc, + bool is_standard) { + float_map_[idx] = f; + std::ostringstream ss; + ss << doc << " (float, default = " << *f << ")"; + doc_map_[idx] = DocInfo(name, ss.str(), is_standard); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, + double *f, + const std::string &doc, + bool is_standard) { + double_map_[idx] = f; + std::ostringstream ss; + ss << doc << " (double, default = " << *f << ")"; + doc_map_[idx] = DocInfo(name, ss.str(), is_standard); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, + std::string *s, + const std::string &doc, + bool is_standard) { + string_map_[idx] = s; + doc_map_[idx] = DocInfo(name, doc + " (string, default = \"" + *s + "\")", + is_standard); +} +void ParseOptions::DisableOption(const std::string &name) { + if (argv_ != NULL) + KALDI_ERR << "DisableOption must not be called after calling Read()."; + if (doc_map_.erase(name) == 0) + KALDI_ERR << "Option " << name + << " was not registered so cannot be disabled: "; + bool_map_.erase(name); + int_map_.erase(name); + uint_map_.erase(name); + float_map_.erase(name); + double_map_.erase(name); + string_map_.erase(name); +} + + +int 
ParseOptions::NumArgs() const { + return positional_args_.size(); +} + +std::string ParseOptions::GetArg(int i) const { + // use KALDI_ERR if code error + if (i < 1 || i > static_cast(positional_args_.size())) + KALDI_ERR << "ParseOptions::GetArg, invalid index " << i; + return positional_args_[i - 1]; +} + +// We currently do not support any other options. +enum ShellType { kBash = 0 }; + +// This can be changed in the code if it ever does need to be changed (as it's +// unlikely that one compilation of this tool-set would use both shells). +static ShellType kShellType = kBash; + +// Returns true if we need to escape a string before putting it into +// a shell (mainly thinking of bash shell, but should work for others) +// This is for the convenience of the user so command-lines that are +// printed out by ParseOptions::Read (with --print-args=true) are +// paste-able into the shell and will run. If you use a different type of +// shell, it might be necessary to change this function. +// But it's mostly a cosmetic issue as it basically affects how +// the program echoes its command-line arguments to the screen. +static bool MustBeQuoted(const std::string &str, ShellType st) { + // Only Bash is supported (for the moment). + KALDI_ASSERT(st == kBash && "Invalid shell type."); + + const char *c = str.c_str(); + if (*c == '\0') { + return true; // Must quote empty string + } else { + const char *ok_chars[2]; + + // These seem not to be interpreted as long as there are no other "bad" + // characters involved (e.g. "," would be interpreted as part of something + // like a{b,c}, but not on its own. + ok_chars[kBash] = "[]~#^_-+=:.,/"; + + // Just want to make sure that a space character doesn't get automatically + // inserted here via an automated style-checking script, like it did before. + KALDI_ASSERT(!strchr(ok_chars[kBash], ' ')); + + for (; *c != '\0'; c++) { + // For non-alphanumeric characters we have a list of characters which + // are OK. 
All others are forbidden (this is easier since the shell + // interprets most non-alphanumeric characters). + if (!isalnum(*c)) { + const char *d; + for (d = ok_chars[st]; *d != '\0'; d++) if (*c == *d) break; + // If not alphanumeric or one of the "ok_chars", it must be escaped. + if (*d == '\0') return true; + } + } + return false; // The string was OK. No quoting or escaping. + } +} + +// Returns a quoted and escaped version of "str" +// which has previously been determined to need escaping. +// Our aim is to print out the command line in such a way that if it's +// pasted into a shell of ShellType "st" (only bash for now), it +// will get passed to the program in the same way. +static std::string QuoteAndEscape(const std::string &str, ShellType st) { + // Only Bash is supported (for the moment). + KALDI_ASSERT(st == kBash && "Invalid shell type."); + + // For now we use the following rules: + // In the normal case, we quote with single-quote "'", and to escape + // a single-quote we use the string: '\'' (interpreted as closing the + // single-quote, putting an escaped single-quote from the shell, and + // then reopening the single quote). + char quote_char = '\''; + const char *escape_str = "'\\''"; // e.g. echo 'a'\''b' returns a'b + + // If the string contains single-quotes that would need escaping this + // way, and we determine that the string could be safely double-quoted + // without requiring any escaping, then we double-quote the string. + // This is the case if the characters "`$\ do not appear in the string. + // e.g. see http://www.redhat.com/mirrors/LDP/LDP/abs/html/quotingvar.html + const char *c_str = str.c_str(); + if (strchr(c_str, '\'') && !strpbrk(c_str, "\"`$\\")) { + quote_char = '"'; + escape_str = "\\\""; // should never be accessed. 
+ } + + char buf[2]; + buf[1] = '\0'; + + buf[0] = quote_char; + std::string ans = buf; + const char *c = str.c_str(); + for (;*c != '\0'; c++) { + if (*c == quote_char) { + ans += escape_str; + } else { + buf[0] = *c; + ans += buf; + } + } + buf[0] = quote_char; + ans += buf; + return ans; +} + +// static function +std::string ParseOptions::Escape(const std::string &str) { + return MustBeQuoted(str, kShellType) ? QuoteAndEscape(str, kShellType) : str; +} + + +int ParseOptions::Read(int argc, const char *const argv[]) { + argc_ = argc; + argv_ = argv; + std::string key, value; + int i; + if (argc > 0) { + // set global "const char*" g_program_name (name of the program) + // so it can be printed out in error messages; + // it's useful because often the stderr of different programs will + // be mixed together in the same log file. +#ifdef _MSC_VER + const char *c = strrchr(argv[0], '\\'); +#else + const char *c = strrchr(argv[0], '/'); +#endif + SetProgramName(c == NULL ? argv[0] : c + 1); + } + // first pass: look for config parameter, look for priority + for (i = 1; i < argc; i++) { + if (std::strncmp(argv[i], "--", 2) == 0) { + if (std::strcmp(argv[i], "--") == 0) { + // a lone "--" marks the end of named options + break; + } + bool has_equal_sign; + SplitLongArg(argv[i], &key, &value, &has_equal_sign); + NormalizeArgName(&key); + Trim(&value); + if (key.compare("config") == 0) { + ReadConfigFile(value); + } + if (key.compare("help") == 0) { + PrintUsage(); + exit(0); + } + } + } + bool double_dash_seen = false; + // second pass: add the command line options + for (i = 1; i < argc; i++) { + if (std::strncmp(argv[i], "--", 2) == 0) { + if (std::strcmp(argv[i], "--") == 0) { + // A lone "--" marks the end of named options. 
+ // Skip that option and break the processing of named options + i += 1; + double_dash_seen = true; + break; + } + bool has_equal_sign; + SplitLongArg(argv[i], &key, &value, &has_equal_sign); + NormalizeArgName(&key); + Trim(&value); + if (!SetOption(key, value, has_equal_sign)) { + PrintUsage(true); + KALDI_ERR << "Invalid option " << argv[i]; + } + } else { + break; + } + } + + // process remaining arguments as positional + for (; i < argc; i++) { + if ((std::strcmp(argv[i], "--") == 0) && !double_dash_seen) { + double_dash_seen = true; + } else { + positional_args_.push_back(std::string(argv[i])); + } + } + + // if the user did not suppress this with --print-args = false.... + if (print_args_) { + std::ostringstream strm; + for (int j = 0; j < argc; j++) + strm << Escape(argv[j]) << " "; + strm << '\n'; + std::cerr << strm.str() << std::flush; + } + return i; +} + + +void ParseOptions::PrintUsage(bool print_command_line) { + std::cerr << '\n' << usage_ << '\n'; + DocMapType::iterator it; + // first we print application-specific options + bool app_specific_header_printed = false; + for (it = doc_map_.begin(); it != doc_map_.end(); ++it) { + if (it->second.is_standard_ == false) { // application-specific option + if (app_specific_header_printed == false) { // header was not yet printed + std::cerr << "Options:" << '\n'; + app_specific_header_printed = true; + } + std::cerr << " --" << std::setw(25) << std::left << it->second.name_ + << " : " << it->second.use_msg_ << '\n'; + } + } + if (app_specific_header_printed == true) { + std::cerr << '\n'; + } + + // then the standard options + std::cerr << "Standard options:" << '\n'; + for (it = doc_map_.begin(); it != doc_map_.end(); ++it) { + if (it->second.is_standard_ == true) { // we have standard option + std::cerr << " --" << std::setw(25) << std::left << it->second.name_ + << " : " << it->second.use_msg_ << '\n'; + } + } + std::cerr << '\n'; + if (print_command_line) { + std::ostringstream strm; + strm << "Command 
line was: "; + for (int j = 0; j < argc_; j++) + strm << Escape(argv_[j]) << " "; + strm << '\n'; + std::cerr << strm.str() << std::flush; + } +} + +void ParseOptions::PrintConfig(std::ostream &os) { + os << '\n' << "[[ Configuration of UI-Registered options ]]" << '\n'; + std::string key; + DocMapType::iterator it; + for (it = doc_map_.begin(); it != doc_map_.end(); ++it) { + key = it->first; + os << it->second.name_ << " = "; + if (bool_map_.end() != bool_map_.find(key)) { + os << (*bool_map_[key] ? "true" : "false"); + } else if (int_map_.end() != int_map_.find(key)) { + os << (*int_map_[key]); + } else if (uint_map_.end() != uint_map_.find(key)) { + os << (*uint_map_[key]); + } else if (float_map_.end() != float_map_.find(key)) { + os << (*float_map_[key]); + } else if (double_map_.end() != double_map_.find(key)) { + os << (*double_map_[key]); + } else if (string_map_.end() != string_map_.find(key)) { + os << "'" << *string_map_[key] << "'"; + } else { + KALDI_ERR << "PrintConfig: unrecognized option " << key << "[code error]"; + } + os << '\n'; + } + os << '\n'; +} + + +void ParseOptions::ReadConfigFile(const std::string &filename) { + std::ifstream is(filename.c_str(), std::ifstream::in); + if (!is.good()) { + KALDI_ERR << "Cannot open config file: " << filename; + } + + std::string line, key, value; + int32 line_number = 0; + while (std::getline(is, line)) { + line_number++; + // trim out the comments + size_t pos; + if ((pos = line.find_first_of('#')) != std::string::npos) { + line.erase(pos); + } + // skip empty lines + Trim(&line); + if (line.length() == 0) continue; + + if (line.substr(0, 2) != "--") { + KALDI_ERR << "Reading config file " << filename + << ": line " << line_number << " does not look like a line " + << "from a Kaldi command-line program's config file: should " + << "be of the form --x=y. 
Note: config files intended to " + << "be sourced by shell scripts lack the '--'."; + } + + // parse option + bool has_equal_sign; + SplitLongArg(line, &key, &value, &has_equal_sign); + NormalizeArgName(&key); + Trim(&value); + if (!SetOption(key, value, has_equal_sign)) { + PrintUsage(true); + KALDI_ERR << "Invalid option " << line << " in config file " << filename; + } + } +} + + + +void ParseOptions::SplitLongArg(const std::string &in, + std::string *key, + std::string *value, + bool *has_equal_sign) { + KALDI_ASSERT(in.substr(0, 2) == "--"); // precondition. + size_t pos = in.find_first_of('=', 0); + if (pos == std::string::npos) { // we allow --option for bools + // defaults to empty. We handle this differently in different cases. + *key = in.substr(2, in.size()-2); // 2 because starts with --. + *value = ""; + *has_equal_sign = false; + } else if (pos == 2) { // we also don't allow empty keys: --=value + PrintUsage(true); + KALDI_ERR << "Invalid option (no key): " << in; + } else { // normal case: --option=value + *key = in.substr(2, pos-2); // 2 because starts with --. 
+ *value = in.substr(pos + 1); + *has_equal_sign = true; + } +} + + +void ParseOptions::NormalizeArgName(std::string *str) { + std::string out; + std::string::iterator it; + + for (it = str->begin(); it != str->end(); ++it) { + if (*it == '_') + out += '-'; // convert _ to - + else + out += std::tolower(*it); + } + *str = out; + + KALDI_ASSERT(str->length() > 0); +} + + + + +bool ParseOptions::SetOption(const std::string &key, + const std::string &value, + bool has_equal_sign) { + if (bool_map_.end() != bool_map_.find(key)) { + if (has_equal_sign && value == "") + KALDI_ERR << "Invalid option --" << key << "="; + *(bool_map_[key]) = ToBool(value); + } else if (int_map_.end() != int_map_.find(key)) { + *(int_map_[key]) = ToInt(value); + } else if (uint_map_.end() != uint_map_.find(key)) { + *(uint_map_[key]) = ToUint(value); + } else if (float_map_.end() != float_map_.find(key)) { + *(float_map_[key]) = ToFloat(value); + } else if (double_map_.end() != double_map_.find(key)) { + *(double_map_[key]) = ToDouble(value); + } else if (string_map_.end() != string_map_.find(key)) { + if (!has_equal_sign) + KALDI_ERR << "Invalid option --" << key + << " (option format is --x=y)."; + *(string_map_[key]) = value; + } else { + return false; + } + return true; +} + + + +bool ParseOptions::ToBool(std::string str) { + std::transform(str.begin(), str.end(), str.begin(), ::tolower); + + // allow "" as a valid option for "true", so that --x is the same as --x=true + if ((str.compare("true") == 0) || (str.compare("t") == 0) + || (str.compare("1") == 0) || (str.compare("") == 0)) { + return true; + } + if ((str.compare("false") == 0) || (str.compare("f") == 0) + || (str.compare("0") == 0)) { + return false; + } + // if it is neither true nor false: + PrintUsage(true); + KALDI_ERR << "Invalid format for boolean argument [expected true or false]: " + << str; + return false; // never reached +} + + +int32 ParseOptions::ToInt(const std::string &str) { + int32 ret; + if 
(!ConvertStringToInteger(str, &ret)) + KALDI_ERR << "Invalid integer option \"" << str << "\""; + return ret; +} + +uint32 ParseOptions::ToUint(const std::string &str) { + uint32 ret; + if (!ConvertStringToInteger(str, &ret)) + KALDI_ERR << "Invalid integer option \"" << str << "\""; + return ret; +} + +float ParseOptions::ToFloat(const std::string &str) { + float ret; + if (!ConvertStringToReal(str, &ret)) + KALDI_ERR << "Invalid floating-point option \"" << str << "\""; + return ret; +} + +double ParseOptions::ToDouble(const std::string &str) { + double ret; + if (!ConvertStringToReal(str, &ret)) + KALDI_ERR << "Invalid floating-point option \"" << str << "\""; + return ret; +} + +// instantiate templates +template void ParseOptions::RegisterTmpl(const std::string &name, bool *ptr, + const std::string &doc); +template void ParseOptions::RegisterTmpl(const std::string &name, int32 *ptr, + const std::string &doc); +template void ParseOptions::RegisterTmpl(const std::string &name, uint32 *ptr, + const std::string &doc); +template void ParseOptions::RegisterTmpl(const std::string &name, float *ptr, + const std::string &doc); +template void ParseOptions::RegisterTmpl(const std::string &name, double *ptr, + const std::string &doc); +template void ParseOptions::RegisterTmpl(const std::string &name, + std::string *ptr, const std::string &doc); + +template void ParseOptions::RegisterStandard(const std::string &name, + bool *ptr, + const std::string &doc); +template void ParseOptions::RegisterStandard(const std::string &name, + int32 *ptr, + const std::string &doc); +template void ParseOptions::RegisterStandard(const std::string &name, + uint32 *ptr, + const std::string &doc); +template void ParseOptions::RegisterStandard(const std::string &name, + float *ptr, + const std::string &doc); +template void ParseOptions::RegisterStandard(const std::string &name, + double *ptr, + const std::string &doc); +template void ParseOptions::RegisterStandard(const std::string &name, + 
std::string *ptr, + const std::string &doc); + +template void ParseOptions::RegisterCommon(const std::string &name, + bool *ptr, + const std::string &doc, bool is_standard); +template void ParseOptions::RegisterCommon(const std::string &name, + int32 *ptr, + const std::string &doc, bool is_standard); +template void ParseOptions::RegisterCommon(const std::string &name, + uint32 *ptr, + const std::string &doc, bool is_standard); +template void ParseOptions::RegisterCommon(const std::string &name, + float *ptr, + const std::string &doc, bool is_standard); +template void ParseOptions::RegisterCommon(const std::string &name, + double *ptr, + const std::string &doc, bool is_standard); +template void ParseOptions::RegisterCommon(const std::string &name, + std::string *ptr, + const std::string &doc, bool is_standard); + +} // namespace kaldi diff --git a/speechx/speechx/kaldi/util/parse-options.h b/speechx/speechx/kaldi/util/parse-options.h new file mode 100644 index 00000000..5e83f996 --- /dev/null +++ b/speechx/speechx/kaldi/util/parse-options.h @@ -0,0 +1,264 @@ +// util/parse-options.h + +// Copyright 2009-2011 Karel Vesely; Microsoft Corporation; +// Saarland University (Author: Arnab Ghoshal); +// Copyright 2012-2013 Frantisek Skala; Arnab Ghoshal + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef KALDI_UTIL_PARSE_OPTIONS_H_ +#define KALDI_UTIL_PARSE_OPTIONS_H_ + +#include +#include +#include + +#include "base/kaldi-common.h" +#include "util/options-itf.h" + +namespace kaldi { + +/// The class ParseOptions is for parsing command-line options; see +/// \ref parse_options for more documentation. +class ParseOptions : public OptionsItf { + public: + explicit ParseOptions(const char *usage) : + print_args_(true), help_(false), usage_(usage), argc_(0), argv_(NULL), + prefix_(""), other_parser_(NULL) { +#if !defined(_MSC_VER) && !defined(__CYGWIN__) // This is just a convenient place to set the stderr to line + setlinebuf(stderr); // buffering mode, since it's called at program start. +#endif // This helps ensure different programs' output is not mixed up. + RegisterStandard("config", &config_, "Configuration file to read (this " + "option may be repeated)"); + RegisterStandard("print-args", &print_args_, + "Print the command line arguments (to stderr)"); + RegisterStandard("help", &help_, "Print out usage message"); + RegisterStandard("verbose", &g_kaldi_verbose_level, + "Verbose level (higher->more logging)"); + } + + /** + This is a constructor for the special case where some options are + registered with a prefix to avoid conflicts. The object thus created will + only be used temporarily to register an options class with the original + options parser (which is passed as the *other pointer) using the given + prefix. It should not be used for any other purpose, and the prefix must + not be the empty string. It seems to be the least bad way of implementing + options with prefixes at this point. + Example of usage is: + ParseOptions po; // original ParseOptions object + ParseOptions po_mfcc("mfcc", &po); // object with prefix. 
+ MfccOptions mfcc_opts; + mfcc_opts.Register(&po_mfcc); + The options will now get registered as, e.g., --mfcc.frame-shift=10.0 + instead of just --frame-shift=10.0 + */ + ParseOptions(const std::string &prefix, OptionsItf *other); + + ~ParseOptions() {} + + // Methods from the interface + void Register(const std::string &name, + bool *ptr, const std::string &doc); + void Register(const std::string &name, + int32 *ptr, const std::string &doc); + void Register(const std::string &name, + uint32 *ptr, const std::string &doc); + void Register(const std::string &name, + float *ptr, const std::string &doc); + void Register(const std::string &name, + double *ptr, const std::string &doc); + void Register(const std::string &name, + std::string *ptr, const std::string &doc); + + /// If called after registering an option and before calling + /// Read(), disables that option from being used. Will crash + /// at runtime if that option had not been registered. + void DisableOption(const std::string &name); + + /// This one is used for registering standard parameters of all the programs + template + void RegisterStandard(const std::string &name, + T *ptr, const std::string &doc); + + /** + Parses the command line options and fills the ParseOptions-registered + variables. This must be called after all the variables were registered!!! + + Initially the variables have implicit values, + then the config file values are set-up, + finally the command line values given. + Returns the first position in argv that was not used. + [typically not useful: use NumParams() and GetParam(). ] + */ + int Read(int argc, const char *const *argv); + + /// Prints the usage documentation [provided in the constructor]. + void PrintUsage(bool print_command_line = false); + /// Prints the actual configuration of all the registered variables + void PrintConfig(std::ostream &os); + + /// Reads the options values from a config file. Must be called after + /// registering all options. 
This is usually used internally after the + /// standard --config option is used, but it may also be called from a + /// program. + void ReadConfigFile(const std::string &filename); + + /// Number of positional parameters (c.f. argc-1). + int NumArgs() const; + + /// Returns one of the positional parameters; 1-based indexing for argc/argv + /// compatibility. Will crash if param is not >=1 and <=NumArgs(). + std::string GetArg(int param) const; + + std::string GetOptArg(int param) const { + return (param <= NumArgs() ? GetArg(param) : ""); + } + + /// The following function will return a possibly quoted and escaped + /// version of "str", according to the current shell. Currently + /// this is just hardwired to bash. It's useful for debug output. + static std::string Escape(const std::string &str); + + private: + /// Template to register various variable types, + /// used for program-specific parameters + template + void RegisterTmpl(const std::string &name, T *ptr, const std::string &doc); + + // Following functions do just the datatype-specific part of the job + /// Register boolean variable + void RegisterSpecific(const std::string &name, const std::string &idx, + bool *b, const std::string &doc, bool is_standard); + /// Register int32 variable + void RegisterSpecific(const std::string &name, const std::string &idx, + int32 *i, const std::string &doc, bool is_standard); + /// Register unsinged int32 variable + void RegisterSpecific(const std::string &name, const std::string &idx, + uint32 *u, + const std::string &doc, bool is_standard); + /// Register float variable + void RegisterSpecific(const std::string &name, const std::string &idx, + float *f, const std::string &doc, bool is_standard); + /// Register double variable [useful as we change BaseFloat type]. 
+ void RegisterSpecific(const std::string &name, const std::string &idx, + double *f, const std::string &doc, bool is_standard); + /// Register string variable + void RegisterSpecific(const std::string &name, const std::string &idx, + std::string *s, const std::string &doc, + bool is_standard); + + /// Does the actual job for both kinds of parameters + /// Does the common part of the job for all datatypes, + /// then calls RegisterSpecific + template + void RegisterCommon(const std::string &name, + T *ptr, const std::string &doc, bool is_standard); + + /// Set option with name "key" to "value"; will crash if can't do it. + /// "has_equal_sign" is used to allow --x for a boolean option x, + /// and --y=, for a string option y. + bool SetOption(const std::string &key, const std::string &value, + bool has_equal_sign); + + bool ToBool(std::string str); + int32 ToInt(const std::string &str); + uint32 ToUint(const std::string &str); + float ToFloat(const std::string &str); + double ToDouble(const std::string &str); + + // maps for option variables + std::map bool_map_; + std::map int_map_; + std::map uint_map_; + std::map float_map_; + std::map double_map_; + std::map string_map_; + + /** + Structure for options' documentation + */ + struct DocInfo { + DocInfo() {} + DocInfo(const std::string &name, const std::string &usemsg) + : name_(name), use_msg_(usemsg), is_standard_(false) {} + DocInfo(const std::string &name, const std::string &usemsg, + bool is_standard) + : name_(name), use_msg_(usemsg), is_standard_(is_standard) {} + + std::string name_; + std::string use_msg_; + bool is_standard_; + }; + typedef std::map DocMapType; + DocMapType doc_map_; ///< map for the documentation + + bool print_args_; ///< variable for the implicit --print-args parameter + bool help_; ///< variable for the implicit --help parameter + std::string config_; ///< variable for the implicit --config parameter + std::vector positional_args_; + const char *usage_; + int argc_; + const char 
*const *argv_; + + /// These members are not normally used. They are only used when the object + /// is constructed with a prefix + std::string prefix_; + OptionsItf *other_parser_; + protected: + /// SplitLongArg parses an argument of the form --a=b, --a=, or --a, + /// and sets "has_equal_sign" to true if an equals-sign was parsed.. + /// this is needed in order to correctly allow --x for a boolean option + /// x, and --y= for a string option y, and to disallow --x= and --y. + void SplitLongArg(const std::string &in, std::string *key, + std::string *value, bool *has_equal_sign); + + void NormalizeArgName(std::string *str); +}; + +/// This template is provided for convenience in reading config classes from +/// files; this is not the standard way to read configuration options, but may +/// occasionally be needed. This function assumes the config has a function +/// "void Register(OptionsItf *opts)" which it can call to register the +/// ParseOptions object. +template void ReadConfigFromFile(const std::string &config_filename, + C *c) { + std::ostringstream usage_str; + usage_str << "Parsing config from " + << "from '" << config_filename << "'"; + ParseOptions po(usage_str.str().c_str()); + c->Register(&po); + po.ReadConfigFile(config_filename); +} + +/// This variant of the template ReadConfigFromFile is for if you need to read +/// two config classes from the same file. 
+template void ReadConfigsFromFile(const std::string &conf, + C1 *c1, C2 *c2) { + std::ostringstream usage_str; + usage_str << "Parsing config from " + << "from '" << conf << "'"; + ParseOptions po(usage_str.str().c_str()); + c1->Register(&po); + c2->Register(&po); + po.ReadConfigFile(conf); +} + + + +} // namespace kaldi + +#endif // KALDI_UTIL_PARSE_OPTIONS_H_ diff --git a/speechx/speechx/kaldi/util/simple-io-funcs.cc b/speechx/speechx/kaldi/util/simple-io-funcs.cc new file mode 100644 index 00000000..cb732a10 --- /dev/null +++ b/speechx/speechx/kaldi/util/simple-io-funcs.cc @@ -0,0 +1,81 @@ +// util/simple-io-funcs.cc + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#include "util/simple-io-funcs.h" +#include "util/text-utils.h" + +namespace kaldi { + +bool WriteIntegerVectorSimple(const std::string &wxfilename, + const std::vector &list) { + kaldi::Output ko; + // false, false is: text-mode, no Kaldi header. 
+ if (!ko.Open(wxfilename, false, false)) return false; + for (size_t i = 0; i < list.size(); i++) ko.Stream() << list[i] << '\n'; + return ko.Close(); +} + +bool ReadIntegerVectorSimple(const std::string &rxfilename, + std::vector *list) { + kaldi::Input ki; + if (!ki.OpenTextMode(rxfilename)) return false; + std::istream &is = ki.Stream(); + int32 i; + list->clear(); + while ( !(is >> i).fail() ) + list->push_back(i); + is >> std::ws; + return is.eof(); // should be eof, or junk at end of file. +} + +bool WriteIntegerVectorVectorSimple(const std::string &wxfilename, + const std::vector > &list) { + kaldi::Output ko; + // false, false is: text-mode, no Kaldi header. + if (!ko.Open(wxfilename, false, false)) return false; + std::ostream &os = ko.Stream(); + for (size_t i = 0; i < list.size(); i++) { + for (size_t j = 0; j < list[i].size(); j++) { + os << list[i][j]; + if (j+1 < list[i].size()) os << ' '; + } + os << '\n'; + } + return ko.Close(); +} + +bool ReadIntegerVectorVectorSimple(const std::string &rxfilename, + std::vector > *list) { + kaldi::Input ki; + if (!ki.OpenTextMode(rxfilename)) return false; + std::istream &is = ki.Stream(); + list->clear(); + std::string line; + while (std::getline(is, line)) { + std::vector v; + if (!SplitStringToIntegers(line, " \t\r", true, &v)) { + list->clear(); + return false; + } + list->push_back(v); + } + return is.eof(); // if we're not at EOF, something weird happened. 
+} + + +} // end namespace kaldi diff --git a/speechx/speechx/kaldi/util/simple-io-funcs.h b/speechx/speechx/kaldi/util/simple-io-funcs.h new file mode 100644 index 00000000..30b90acb --- /dev/null +++ b/speechx/speechx/kaldi/util/simple-io-funcs.h @@ -0,0 +1,63 @@ +// util/simple-io-funcs.h + +// Copyright 2009-2011 Microsoft Corporation; Jan Silovsky + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_UTIL_SIMPLE_IO_FUNCS_H_ +#define KALDI_UTIL_SIMPLE_IO_FUNCS_H_ + +#include +#include +#include "util/kaldi-io.h" + +// This header contains some utilities for reading some common, simple text +// formats:integers in files, one per line, and integers in files, possibly +// multiple per line. these are not really fully native Kaldi formats; they are +// mostly for small files that might be generated by scripts, and can be read +// all at one time. for longer files of this type, we would probably use the +// Table code. + +namespace kaldi { + +/// WriteToList attempts to write this list of integers, one per line, +/// to the given file, in text format. +/// returns true if succeeded. +bool WriteIntegerVectorSimple(const std::string &wxfilename, + const std::vector &v); + +/// ReadFromList attempts to read this list of integers, one per line, +/// from the given file, in text format. +/// returns true if succeeded. 
+bool ReadIntegerVectorSimple(const std::string &rxfilename, + std::vector *v); + +// This is a file format like: +// 1 2 +// 3 +// +// 4 5 6 +// etc. +bool WriteIntegerVectorVectorSimple(const std::string &wxfilename, + const std::vector > &v); + +bool ReadIntegerVectorVectorSimple(const std::string &rxfilename, + std::vector > *v); + + +} // end namespace kaldi. + + +#endif // KALDI_UTIL_SIMPLE_IO_FUNCS_H_ diff --git a/speechx/speechx/kaldi/util/simple-options.cc b/speechx/speechx/kaldi/util/simple-options.cc new file mode 100644 index 00000000..592500e2 --- /dev/null +++ b/speechx/speechx/kaldi/util/simple-options.cc @@ -0,0 +1,184 @@ +// util/simple-options.cc + +// Copyright 2013 Tanel Alumae, Tallinn University of Technology + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ + +#include "util/simple-options.h" + + +namespace kaldi { + +void SimpleOptions::Register(const std::string &name, + bool *value, + const std::string &doc) { + bool_map_[name] = value; + option_info_list_.push_back(std::make_pair(name, OptionInfo(doc, kBool))); +} + +void SimpleOptions::Register(const std::string &name, + int32 *value, + const std::string &doc) { + int_map_[name] = value; + option_info_list_.push_back(std::make_pair(name, OptionInfo(doc, kInt32))); +} + +void SimpleOptions::Register(const std::string &name, + uint32 *value, + const std::string &doc) { + uint_map_[name] = value; + option_info_list_.push_back(std::make_pair(name, OptionInfo(doc, kUint32))); +} + +void SimpleOptions::Register(const std::string &name, + float *value, + const std::string &doc) { + float_map_[name] = value; + option_info_list_.push_back(std::make_pair(name, OptionInfo(doc, kFloat))); +} + +void SimpleOptions::Register(const std::string &name, + double *value, + const std::string &doc) { + double_map_[name] = value; + option_info_list_.push_back(std::make_pair(name, OptionInfo(doc, kDouble))); +} + +void SimpleOptions::Register(const std::string &name, + std::string *value, + const std::string &doc) { + string_map_[name] = value; + option_info_list_.push_back(std::make_pair(name, OptionInfo(doc, kString))); +} + +template +static bool SetOptionImpl(const std::string &key, const T &value, + std::map &some_map) { + if (some_map.end() != some_map.find(key)) { + *(some_map[key]) = value; + return true; + } + return false; +} + +bool SimpleOptions::SetOption(const std::string &key, const bool &value) { + return SetOptionImpl(key, value, bool_map_); +} + +bool SimpleOptions::SetOption(const std::string &key, const int32 &value) { + if (!SetOptionImpl(key, value, int_map_)) { + if (!SetOptionImpl(key, static_cast(value), uint_map_)) { + return false; + } + } + return true; +} + +bool SimpleOptions::SetOption(const std::string &key, const uint32 &value) { + if 
(!SetOptionImpl(key, value, uint_map_)) { + if (!SetOptionImpl(key, static_cast(value), int_map_)) { + return false; + } + } + return true; +} + +bool SimpleOptions::SetOption(const std::string &key, const float &value) { + if (!SetOptionImpl(key, value, float_map_)) { + if (!SetOptionImpl(key, static_cast(value), double_map_)) { + return false; + } + } + return true; +} + +bool SimpleOptions::SetOption(const std::string &key, const double &value) { + if (!SetOptionImpl(key, value, double_map_)) { + if (!SetOptionImpl(key, static_cast(value), float_map_)) { + return false; + } + } + return true; +} + +bool SimpleOptions::SetOption(const std::string &key, + const std::string &value) { + return SetOptionImpl(key, value, string_map_); +} + +bool SimpleOptions::SetOption(const std::string &key, const char *value) { + std::string str_value = std::string(value); + return SetOptionImpl(key, str_value, string_map_); +} + + +template +static bool GetOptionImpl(const std::string &key, T *value, + std::map &some_map) { + typename std::map::iterator it = some_map.find(key); + if (it != some_map.end()) { + *value = *(it->second); + return true; + } + return false; +} + +bool SimpleOptions::GetOption(const std::string &key, bool *value) { + return GetOptionImpl(key, value, bool_map_); +} + +bool SimpleOptions::GetOption(const std::string &key, int32 *value) { + return GetOptionImpl(key, value, int_map_); +} + +bool SimpleOptions::GetOption(const std::string &key, uint32 *value) { + return GetOptionImpl(key, value, uint_map_); +} + +bool SimpleOptions::GetOption(const std::string &key, float *value) { + return GetOptionImpl(key, value, float_map_); +} + +bool SimpleOptions::GetOption(const std::string &key, double *value) { + return GetOptionImpl(key, value, double_map_); +} + +bool SimpleOptions::GetOption(const std::string &key, std::string *value) { + return GetOptionImpl(key, value, string_map_); +} + +std::vector > +SimpleOptions::GetOptionInfoList() { + return 
option_info_list_; +} + +bool SimpleOptions::GetOptionType(const std::string &key, OptionType *type) { + for (std::vector >::iterator dx = option_info_list_.begin(); + dx != option_info_list_.end(); dx++) { + std::pair info_pair = (*dx); + if (info_pair.first == key) { + *type = info_pair.second.type; + return true; + } + } + return false; +} + + + +} // namespace kaldi diff --git a/speechx/speechx/kaldi/util/simple-options.h b/speechx/speechx/kaldi/util/simple-options.h new file mode 100644 index 00000000..f301c7d6 --- /dev/null +++ b/speechx/speechx/kaldi/util/simple-options.h @@ -0,0 +1,113 @@ +// util/simple-options.h + +// Copyright 2013 Tanel Alumae, Tallinn University of Technology + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_UTIL_SIMPLE_OPTIONS_H_ +#define KALDI_UTIL_SIMPLE_OPTIONS_H_ + +#include +#include +#include + +#include "base/kaldi-common.h" +#include "util/options-itf.h" + +namespace kaldi { + + +/// The class SimpleOptions is an implementation of OptionsItf that allows +/// setting and getting option values programmatically, i.e., via getter +/// and setter methods. It doesn't provide any command line parsing +/// functionality. +/// The class ParseOptions should be used for command-line options. 
+class SimpleOptions : public OptionsItf { + public: + SimpleOptions() { + } + + virtual ~SimpleOptions() { + } + + // Methods from the interface + void Register(const std::string &name, bool *ptr, const std::string &doc); + void Register(const std::string &name, int32 *ptr, const std::string &doc); + void Register(const std::string &name, uint32 *ptr, const std::string &doc); + void Register(const std::string &name, float *ptr, const std::string &doc); + void Register(const std::string &name, double *ptr, const std::string &doc); + void Register(const std::string &name, std::string *ptr, + const std::string &doc); + + // set option with the specified key, return true if successful + bool SetOption(const std::string &key, const bool &value); + bool SetOption(const std::string &key, const int32 &value); + bool SetOption(const std::string &key, const uint32 &value); + bool SetOption(const std::string &key, const float &value); + bool SetOption(const std::string &key, const double &value); + bool SetOption(const std::string &key, const std::string &value); + bool SetOption(const std::string &key, const char* value); + + // get option with the specified key and put to 'value', + // return true if successful + bool GetOption(const std::string &key, bool *value); + bool GetOption(const std::string &key, int32 *value); + bool GetOption(const std::string &key, uint32 *value); + bool GetOption(const std::string &key, float *value); + bool GetOption(const std::string &key, double *value); + bool GetOption(const std::string &key, std::string *value); + + enum OptionType { + kBool, + kInt32, + kUint32, + kFloat, + kDouble, + kString + }; + + struct OptionInfo { + OptionInfo(const std::string &doc, OptionType type) : + doc(doc), type(type) { + } + std::string doc; + OptionType type; + }; + + std::vector > GetOptionInfoList(); + + /* + * Puts the type of the option with name 'key' in the argument 'type'. + * Return true if such option is found, false otherwise. 
+ */ + bool GetOptionType(const std::string &key, OptionType *type); + + private: + + std::vector > option_info_list_; + + // maps for option variables + std::map bool_map_; + std::map int_map_; + std::map uint_map_; + std::map float_map_; + std::map double_map_; + std::map string_map_; +}; + +} // namespace kaldi + +#endif // KALDI_UTIL_SIMPLE_OPTIONS_H_ diff --git a/speechx/speechx/kaldi/util/stl-utils.h b/speechx/speechx/kaldi/util/stl-utils.h new file mode 100644 index 00000000..647073a2 --- /dev/null +++ b/speechx/speechx/kaldi/util/stl-utils.h @@ -0,0 +1,317 @@ +// util/stl-utils.h + +// Copyright 2009-2011 Microsoft Corporation; Saarland University + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_UTIL_STL_UTILS_H_ +#define KALDI_UTIL_STL_UTILS_H_ + +#include +#include +using std::unordered_map; +using std::unordered_set; + +#include +#include +#include +#include +#include +#include "base/kaldi-common.h" + +namespace kaldi { + +/// Sorts and uniq's (removes duplicates) from a vector. +template +inline void SortAndUniq(std::vector *vec) { + std::sort(vec->begin(), vec->end()); + vec->erase(std::unique(vec->begin(), vec->end()), vec->end()); +} + + +/// Returns true if the vector is sorted. 
+template<typename T>
+inline bool IsSorted(const std::vector<T> &vec) {
+  typename std::vector<T>::const_iterator iter = vec.begin(), end = vec.end();
+  if (iter == end) return true;
+  while (1) {
+    typename std::vector<T>::const_iterator next_iter = iter;
+    ++next_iter;
+    if (next_iter == end) return true;  // end of loop and nothing out of order
+    if (*next_iter < *iter) return false;
+    iter = next_iter;
+  }
+}
+
+
+/// Returns true if the vector is sorted and contains each element
+/// only once.
+template<typename T>
+inline bool IsSortedAndUniq(const std::vector<T> &vec) {
+  typename std::vector<T>::const_iterator iter = vec.begin(), end = vec.end();
+  if (iter == end) return true;
+  while (1) {
+    typename std::vector<T>::const_iterator next_iter = iter;
+    ++next_iter;
+    if (next_iter == end) return true;  // end of loop and nothing out of order
+    if (*next_iter <= *iter) return false;
+    iter = next_iter;
+  }
+}
+
+
+/// Removes duplicate elements from a sorted list.
+template<typename T>
+inline void Uniq(std::vector<T> *vec) {  // must be already sorted.
+  KALDI_PARANOID_ASSERT(IsSorted(*vec));
+  KALDI_ASSERT(vec);
+  vec->erase(std::unique(vec->begin(), vec->end()), vec->end());
+}
+
+/// Copies the elements of a set to a vector.
+template<class T>
+void CopySetToVector(const std::set<T> &s, std::vector<T> *v) {
+  // copies members of s into v, in sorted order from lowest to highest
+  // (because the set was in sorted order).
+  KALDI_ASSERT(v != NULL);
+  v->resize(s.size());
+  typename std::set<T>::const_iterator siter = s.begin(), send = s.end();
+  typename std::vector<T>::iterator viter = v->begin();
+  for (; siter != send; ++siter, ++viter) {
+    *viter = *siter;
+  }
+}
+
+template<class T>
+void CopySetToVector(const unordered_set<T> &s, std::vector<T> *v) {
+  KALDI_ASSERT(v != NULL);
+  v->resize(s.size());
+  typename unordered_set<T>::const_iterator siter = s.begin(), send = s.end();
+  typename std::vector<T>::iterator viter = v->begin();
+  for (; siter != send; ++siter, ++viter) {
+    *viter = *siter;
+  }
+}
+
+
+/// Copies the (key, value) pairs in a map to a vector of pairs.
+template<class A, class B>
+void CopyMapToVector(const std::map<A, B> &m,
+                     std::vector<std::pair<A, B> > *v) {
+  KALDI_ASSERT(v != NULL);
+  v->resize(m.size());
+  typename std::map<A, B>::const_iterator miter = m.begin(), mend = m.end();
+  typename std::vector<std::pair<A, B> >::iterator viter = v->begin();
+  for (; miter != mend; ++miter, ++viter) {
+    *viter = std::make_pair(miter->first, miter->second);
+    // do it like this because of const casting.
+  }
+}
+
+/// Copies the keys in a map to a vector.
+template<class A, class B>
+void CopyMapKeysToVector(const std::map<A, B> &m, std::vector<A> *v) {
+  KALDI_ASSERT(v != NULL);
+  v->resize(m.size());
+  typename std::map<A, B>::const_iterator miter = m.begin(), mend = m.end();
+  typename std::vector<A>::iterator viter = v->begin();
+  for (; miter != mend; ++miter, ++viter) {
+    *viter = miter->first;
+  }
+}
+
+/// Copies the values in a map to a vector.
+template<class A, class B>
+void CopyMapValuesToVector(const std::map<A, B> &m, std::vector<B> *v) {
+  KALDI_ASSERT(v != NULL);
+  v->resize(m.size());
+  typename std::map<A, B>::const_iterator miter = m.begin(), mend = m.end();
+  typename std::vector<B>::iterator viter = v->begin();
+  for (; miter != mend; ++miter, ++viter) {
+    *viter = miter->second;
+  }
+}
+
+/// Copies the keys in a map to a set.
+template<class A, class B>
+void CopyMapKeysToSet(const std::map<A, B> &m, std::set<A> *s) {
+  KALDI_ASSERT(s != NULL);
+  s->clear();
+  typename std::map<A, B>::const_iterator miter = m.begin(), mend = m.end();
+  for (; miter != mend; ++miter) {
+    s->insert(s->end(), miter->first);
+  }
+}
+
+/// Copies the values in a map to a set.
+template<class A, class B>
+void CopyMapValuesToSet(const std::map<A, B> &m, std::set<B> *s) {
+  KALDI_ASSERT(s != NULL);
+  s->clear();
+  typename std::map<A, B>::const_iterator miter = m.begin(), mend = m.end();
+  for (; miter != mend; ++miter)
+    s->insert(s->end(), miter->second);
+}
+
+
+/// Copies the contents of a vector to a set.
+template<class A>
+void CopyVectorToSet(const std::vector<A> &v, std::set<A> *s) {
+  KALDI_ASSERT(s != NULL);
+  s->clear();
+  typename std::vector<A>::const_iterator iter = v.begin(), end = v.end();
+  for (; iter != end; ++iter)
+    s->insert(s->end(), *iter);
+  // s->end() is a hint in case v was sorted.  will work regardless.
+}
+
+/// Deletes any non-NULL pointers in the vector v, and sets
+/// the corresponding entries of v to NULL
+template<class A>
+void DeletePointers(std::vector<A*> *v) {
+  KALDI_ASSERT(v != NULL);
+  typename std::vector<A*>::iterator iter = v->begin(), end = v->end();
+  for (; iter != end; ++iter) {
+    if (*iter != NULL) {
+      delete *iter;
+      *iter = NULL;  // set to NULL for extra safety.
+    }
+  }
+}
+
+/// Returns true if the vector of pointers contains NULL pointers.
+template<class A>
+bool ContainsNullPointers(const std::vector<A*> &v) {
+  typename std::vector<A*>::const_iterator iter = v.begin(), end = v.end();
+  for (; iter != end; ++iter)
+    if (*iter == static_cast<A*> (NULL)) return true;
+  return false;
+}
+
+/// Copies the contents a vector of one type to a vector
+/// of another type.
+template<typename A, typename B>
+void CopyVectorToVector(const std::vector<A> &vec_in, std::vector<B> *vec_out) {
+  KALDI_ASSERT(vec_out != NULL);
+  vec_out->resize(vec_in.size());
+  for (size_t i = 0; i < vec_in.size(); i++)
+    (*vec_out)[i] = static_cast<B> (vec_in[i]);
+}
+
+/// A hashing function-object for vectors.
+template<class Int>
+struct VectorHasher {  // hashing function for vector<Int>.
+  size_t operator()(const std::vector<Int> &x) const noexcept {
+    size_t ans = 0;
+    typename std::vector<Int>::const_iterator iter = x.begin(), end = x.end();
+    for (; iter != end; ++iter) {
+      ans *= kPrime;
+      ans += *iter;
+    }
+    return ans;
+  }
+  VectorHasher() {  // Check we're instantiated with an integer type.
+    KALDI_ASSERT_IS_INTEGER_TYPE(Int);
+  }
+ private:
+  static const int kPrime = 7853;
+};
+
+/// A hashing function-object for pairs of ints
+template<typename Int1, typename Int2 = Int1>
+struct PairHasher {  // hashing function for pair<int, int>
+  size_t operator()(const std::pair<Int1, Int2> &x) const noexcept {
+    // 7853 was chosen at random from a list of primes.
+    return x.first + x.second * 7853;
+  }
+  PairHasher() {  // Check we're instantiated with an integer type.
+    KALDI_ASSERT_IS_INTEGER_TYPE(Int1);
+    KALDI_ASSERT_IS_INTEGER_TYPE(Int2);
+  }
+};
+
+
+/// A hashing function object for strings.
+struct StringHasher {  // hashing function for std::string
+  size_t operator()(const std::string &str) const noexcept {
+    size_t ans = 0, len = str.length();
+    const char *c = str.c_str(), *end = c + len;
+    for (; c != end; c++) {
+      ans *= kPrime;
+      ans += *c;
+    }
+    return ans;
+  }
+ private:
+  static const int kPrime = 7853;
+};
+
+/// Reverses the contents of a vector.
+template<typename T>
+inline void ReverseVector(std::vector<T> *vec) {
+  KALDI_ASSERT(vec != NULL);
+  size_t sz = vec->size();
+  for (size_t i = 0; i < sz/2; i++)
+    std::swap( (*vec)[i], (*vec)[sz-1-i]);
+}
+
+
+/// Comparator object for pairs that compares only the first pair.
+template<class A, class B>
+struct CompareFirstMemberOfPair {
+  inline bool operator() (const std::pair<A, B> &p1,
+                          const std::pair<A, B> &p2) {
+    return p1.first < p2.first;
+  }
+};
+
+/// For a vector of pair<I, F> where I is an integer and F a floating-point or
+/// integer type, this function sorts a vector of type vector<pair<I, F> > on
+/// the I value and then merges elements with equal I values, summing these over
+/// the F component and then removing any F component with zero value.  This
+/// is for where the vector of pairs represents a map from the integer to float
+/// component, with an "adding" type of semantics for combining the elements.
+template<typename I, typename F>
+inline void MergePairVectorSumming(std::vector<std::pair<I, F> > *vec) {
+  KALDI_ASSERT_IS_INTEGER_TYPE(I);
+  CompareFirstMemberOfPair<I, F> c;
+  std::sort(vec->begin(), vec->end(), c);  // sort on 1st element.
+  typename std::vector<std::pair<I, F> >::iterator out = vec->begin(),
+      in = vec->begin(), end = vec->end();
+  // special case: while there is nothing to be changed, skip over
+  // initial input (avoids unnecessary copying).
+  while (in + 1 < end && in[0].first != in[1].first && in[0].second != 0.0) {
+    in++;
+    out++;
+  }
+  while (in < end) {
+    // We reach this point only at the first element of
+    // each stretch of identical .first elements.
+    *out = *in;
+    ++in;
+    while (in < end && in->first == out->first) {
+      out->second += in->second;  // this is the merge operation.
+      ++in;
+    }
+    if (out->second != static_cast<F>(0))  // Don't keep zero elements.
+      out++;
+  }
+  vec->erase(out, end);
+}
+
+}  // namespace kaldi
+
+#endif  // KALDI_UTIL_STL_UTILS_H_
diff --git a/speechx/speechx/kaldi/util/table-types.h b/speechx/speechx/kaldi/util/table-types.h
new file mode 100644
index 00000000..efcdf1b5
--- /dev/null
+++ b/speechx/speechx/kaldi/util/table-types.h
@@ -0,0 +1,192 @@
+// util/table-types.h
+
+// Copyright 2009-2011  Microsoft Corporation
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+
+#ifndef KALDI_UTIL_TABLE_TYPES_H_
+#define KALDI_UTIL_TABLE_TYPES_H_
+#include "base/kaldi-common.h"
+#include "util/kaldi-table.h"
+#include "util/kaldi-holder.h"
+#include "matrix/matrix-lib.h"
+
+namespace kaldi {
+
+// This header defines typedefs that are specific instantiations of
+// the Table types.
+
+/// \addtogroup table_types
+/// @{
+
+typedef TableWriter<KaldiObjectHolder<Matrix<BaseFloat> > >
+                    BaseFloatMatrixWriter;
+typedef SequentialTableReader<KaldiObjectHolder<Matrix<BaseFloat> > >
+                    SequentialBaseFloatMatrixReader;
+typedef RandomAccessTableReader<KaldiObjectHolder<Matrix<BaseFloat> > >
+                    RandomAccessBaseFloatMatrixReader;
+typedef RandomAccessTableReaderMapped<KaldiObjectHolder<Matrix<BaseFloat> > >
+                    RandomAccessBaseFloatMatrixReaderMapped;
+
+typedef TableWriter<KaldiObjectHolder<Matrix<double> > >
+                    DoubleMatrixWriter;
+typedef SequentialTableReader<KaldiObjectHolder<Matrix<double> > >
+                    SequentialDoubleMatrixReader;
+typedef RandomAccessTableReader<KaldiObjectHolder<Matrix<double> > >
+                    RandomAccessDoubleMatrixReader;
+typedef RandomAccessTableReaderMapped<KaldiObjectHolder<Matrix<double> > >
+                    RandomAccessDoubleMatrixReaderMapped;
+
+typedef TableWriter<KaldiObjectHolder<CompressedMatrix> >
+                    CompressedMatrixWriter;
+
+typedef TableWriter<KaldiObjectHolder<Vector<BaseFloat> > >
+                    BaseFloatVectorWriter;
+typedef SequentialTableReader<KaldiObjectHolder<Vector<BaseFloat> > >
+                    SequentialBaseFloatVectorReader;
+typedef RandomAccessTableReader<KaldiObjectHolder<Vector<BaseFloat> > >
+                    RandomAccessBaseFloatVectorReader;
+typedef RandomAccessTableReaderMapped<KaldiObjectHolder<Vector<BaseFloat> > >
+                    RandomAccessBaseFloatVectorReaderMapped;
+
+typedef TableWriter<KaldiObjectHolder<Vector<double> > >
+                    DoubleVectorWriter;
+typedef SequentialTableReader<KaldiObjectHolder<Vector<double> > >
+                    SequentialDoubleVectorReader;
+typedef RandomAccessTableReader<KaldiObjectHolder<Vector<double> > >
+                    RandomAccessDoubleVectorReader;
+
+typedef TableWriter<KaldiObjectHolder<CuMatrix<BaseFloat> > >
+                    BaseFloatCuMatrixWriter;
+typedef SequentialTableReader<KaldiObjectHolder<CuMatrix<BaseFloat> > >
+                    SequentialBaseFloatCuMatrixReader;
+typedef RandomAccessTableReader<KaldiObjectHolder<CuMatrix<BaseFloat> > >
+                    RandomAccessBaseFloatCuMatrixReader;
+typedef RandomAccessTableReaderMapped<KaldiObjectHolder<CuMatrix<BaseFloat> > >
+                    RandomAccessBaseFloatCuMatrixReaderMapped;
+
+typedef TableWriter<KaldiObjectHolder<CuMatrix<double> > >
+                    DoubleCuMatrixWriter;
+typedef SequentialTableReader<KaldiObjectHolder<CuMatrix<double> > >
+                    SequentialDoubleCuMatrixReader;
+typedef RandomAccessTableReader<KaldiObjectHolder<CuMatrix<double> > >
+                    RandomAccessDoubleCuMatrixReader;
+typedef RandomAccessTableReaderMapped<KaldiObjectHolder<CuMatrix<double> > >
+                    RandomAccessDoubleCuMatrixReaderMapped;
+
+typedef TableWriter<KaldiObjectHolder<CuVector<BaseFloat> > >
+                    BaseFloatCuVectorWriter;
+typedef SequentialTableReader<KaldiObjectHolder<CuVector<BaseFloat> > >
+                    SequentialBaseFloatCuVectorReader;
+typedef RandomAccessTableReader<KaldiObjectHolder<CuVector<BaseFloat> > >
+                    RandomAccessBaseFloatCuVectorReader;
+typedef RandomAccessTableReaderMapped<KaldiObjectHolder<CuVector<BaseFloat> > >
+                    RandomAccessBaseFloatCuVectorReaderMapped;
+
+typedef TableWriter<KaldiObjectHolder<CuVector<double> > >
+                    DoubleCuVectorWriter;
+typedef SequentialTableReader<KaldiObjectHolder<CuVector<double> > >
+                    SequentialDoubleCuVectorReader;
+typedef RandomAccessTableReader<KaldiObjectHolder<CuVector<double> > >
+                    RandomAccessDoubleCuVectorReader;
+
+
+typedef TableWriter<BasicHolder<int32> >  Int32Writer;
+typedef SequentialTableReader<BasicHolder<int32> >  SequentialInt32Reader;
+typedef RandomAccessTableReader<BasicHolder<int32> >  RandomAccessInt32Reader;
+
+typedef TableWriter<BasicVectorHolder<int32> >  Int32VectorWriter;
+typedef SequentialTableReader<BasicVectorHolder<int32> >
+                    SequentialInt32VectorReader;
+typedef RandomAccessTableReader<BasicVectorHolder<int32> >
+                    RandomAccessInt32VectorReader;
+
+typedef TableWriter<BasicVectorVectorHolder<int32> >  Int32VectorVectorWriter;
+typedef SequentialTableReader<BasicVectorVectorHolder<int32> >
+                    SequentialInt32VectorVectorReader;
+typedef RandomAccessTableReader<BasicVectorVectorHolder<int32> >
+                    RandomAccessInt32VectorVectorReader;
+
+typedef TableWriter<BasicPairVectorHolder<int32> >  Int32PairVectorWriter;
+typedef SequentialTableReader<BasicPairVectorHolder<int32> >
+                    SequentialInt32PairVectorReader;
+typedef RandomAccessTableReader<BasicPairVectorHolder<int32> >
+                    RandomAccessInt32PairVectorReader;
+
+typedef TableWriter<BasicPairVectorHolder<BaseFloat> >
+                    BaseFloatPairVectorWriter;
+typedef SequentialTableReader<BasicPairVectorHolder<BaseFloat> >
+                    SequentialBaseFloatPairVectorReader;
+typedef RandomAccessTableReader<BasicPairVectorHolder<BaseFloat> >
+                    RandomAccessBaseFloatPairVectorReader;
+
+typedef TableWriter<BasicHolder<BaseFloat> >  BaseFloatWriter;
+typedef SequentialTableReader<BasicHolder<BaseFloat> >
+                    SequentialBaseFloatReader;
+typedef RandomAccessTableReader<BasicHolder<BaseFloat> >
+                    RandomAccessBaseFloatReader;
+typedef RandomAccessTableReaderMapped<BasicHolder<BaseFloat> >
+                    RandomAccessBaseFloatReaderMapped;
+
+typedef TableWriter<BasicHolder<double> >  DoubleWriter;
+typedef SequentialTableReader<BasicHolder<double> >  SequentialDoubleReader;
+typedef RandomAccessTableReader<BasicHolder<double> >  RandomAccessDoubleReader;
+
+typedef TableWriter<BasicHolder<bool> >  BoolWriter;
+typedef SequentialTableReader<BasicHolder<bool> >  SequentialBoolReader;
+typedef RandomAccessTableReader<BasicHolder<bool> >  RandomAccessBoolReader;
+
+
+
+/// TokenWriter is a writer specialized for std::string where the strings
+/// are nonempty and whitespace-free.   T == std::string
+typedef TableWriter<TokenHolder> TokenWriter;
+typedef SequentialTableReader<TokenHolder> SequentialTokenReader;
+typedef RandomAccessTableReader<TokenHolder> RandomAccessTokenReader;
+
+
+/// TokenVectorWriter is a writer specialized for sequences of
+/// std::string where the strings are nonempty and whitespace-free.
+/// T == std::vector<std::string>
+typedef TableWriter<TokenVectorHolder> TokenVectorWriter;
+// Ditto for SequentialTokenVectorReader.
+typedef SequentialTableReader<TokenVectorHolder> SequentialTokenVectorReader;
+typedef RandomAccessTableReader<TokenVectorHolder>
+                    RandomAccessTokenVectorReader;
+
+
+typedef TableWriter<KaldiObjectHolder<GeneralMatrix> >
+                    GeneralMatrixWriter;
+typedef SequentialTableReader<KaldiObjectHolder<GeneralMatrix> >
+                    SequentialGeneralMatrixReader;
+typedef RandomAccessTableReader<KaldiObjectHolder<GeneralMatrix> >
+                    RandomAccessGeneralMatrixReader;
+typedef RandomAccessTableReaderMapped<KaldiObjectHolder<GeneralMatrix> >
+                    RandomAccessGeneralMatrixReaderMapped;
+
+
+
+/// @}
+
+// Note: for FST reader/writer, see ../fstext/fstext-utils.h
+// [not done yet].
+
+}  // end namespace kaldi
+
+
+
+#endif  // KALDI_UTIL_TABLE_TYPES_H_
diff --git a/speechx/speechx/kaldi/util/text-utils.cc b/speechx/speechx/kaldi/util/text-utils.cc
new file mode 100644
index 00000000..bbf38ecc
--- /dev/null
+++ b/speechx/speechx/kaldi/util/text-utils.cc
@@ -0,0 +1,591 @@
+// util/text-utils.cc
+
+// Copyright 2009-2011  Saarland University;  Microsoft Corporation
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+//  http://www.apache.org/licenses/LICENSE-2.0
+
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#include "util/text-utils.h"
+#include <limits>
+#include <map>
+#include <sstream>
+#include "base/kaldi-common.h"
+
+namespace kaldi {
+
+
+template<class F>
+bool SplitStringToFloats(const std::string &full,
+                         const char *delim,
+                         bool omit_empty_strings,  // typically false
+                         std::vector<F> *out) {
+  KALDI_ASSERT(out != NULL);
+  if (*(full.c_str()) == '\0') {
+    out->clear();
+    return true;
+  }
+  std::vector<std::string> split;
+  SplitStringToVector(full, delim, omit_empty_strings, &split);
+  out->resize(split.size());
+  for (size_t i = 0; i < split.size(); i++) {
+    F f = 0;
+    if (!ConvertStringToReal(split[i], &f))
+      return false;
+    (*out)[i] = f;
+  }
+  return true;
+}
+
+// Instantiate the template above for float and double.
+template
+bool SplitStringToFloats(const std::string &full,
+                         const char *delim,
+                         bool omit_empty_strings,
+                         std::vector<float> *out);
+template
+bool SplitStringToFloats(const std::string &full,
+                         const char *delim,
+                         bool omit_empty_strings,
+                         std::vector<double> *out);
+
+void SplitStringToVector(const std::string &full, const char *delim,
+                         bool omit_empty_strings,
+                         std::vector<std::string> *out) {
+  size_t start = 0, found = 0, end = full.size();
+  out->clear();
+  while (found != std::string::npos) {
+    found = full.find_first_of(delim, start);
+    // start != end condition is for when the delimiter is at the end
+    if (!omit_empty_strings || (found != start && start != end))
+      out->push_back(full.substr(start, found - start));
+    start = found + 1;
+  }
+}
+
+void JoinVectorToString(const std::vector<std::string> &vec_in,
+                        const char *delim, bool omit_empty_strings,
+                        std::string *str_out) {
+  std::string tmp_str;
+  for (size_t i = 0; i < vec_in.size(); i++) {
+    if (!omit_empty_strings || !vec_in[i].empty()) {
+      tmp_str.append(vec_in[i]);
+      if (i < vec_in.size() - 1)
+        if (!omit_empty_strings || !vec_in[i+1].empty())
+          tmp_str.append(delim);
+    }
+  }
+  str_out->swap(tmp_str);
+}
+
+void Trim(std::string *str) {
+  const char *white_chars = " \t\n\r\f\v";
+
+  std::string::size_type pos = str->find_last_not_of(white_chars);
+  if (pos != std::string::npos) {
+    str->erase(pos + 1);
+    pos = str->find_first_not_of(white_chars);
+    if (pos != std::string::npos) str->erase(0, pos);
+  } else {
+    str->erase(str->begin(), str->end());
+  }
+}
+
+bool IsToken(const std::string &token) {
+  size_t l = token.length();
+  if (l == 0) return false;
+  for (size_t i = 0; i < l; i++) {
+    unsigned char c = token[i];
+    if ((!isprint(c) || isspace(c)) && (isascii(c) || c == (unsigned char)255))
+      return false;
+    // The "&& (isascii(c) || c == 255)" was added so that we won't reject
+    // non-ASCII characters such as French characters with accents [except for
+    // 255 which is "nbsp", a form of space].
+  }
+  return true;
+}
+
+
+void SplitStringOnFirstSpace(const std::string &str,
+                             std::string *first,
+                             std::string *rest) {
+  const char *white_chars = " \t\n\r\f\v";
+  typedef std::string::size_type I;
+  const I npos = std::string::npos;
+  I first_nonwhite = str.find_first_not_of(white_chars);
+  if (first_nonwhite == npos) {
+    first->clear();
+    rest->clear();
+    return;
+  }
+  // next_white is first whitespace after first nonwhitespace.
+  I next_white = str.find_first_of(white_chars, first_nonwhite);
+
+  if (next_white == npos) {  // no more whitespace...
+    *first = std::string(str, first_nonwhite);
+    rest->clear();
+    return;
+  }
+  I next_nonwhite = str.find_first_not_of(white_chars, next_white);
+  if (next_nonwhite == npos) {
+    *first = std::string(str, first_nonwhite, next_white-first_nonwhite);
+    rest->clear();
+    return;
+  }
+
+  I last_nonwhite = str.find_last_not_of(white_chars);
+  KALDI_ASSERT(last_nonwhite != npos);  // or coding error.
+
+  *first = std::string(str, first_nonwhite, next_white-first_nonwhite);
+  *rest = std::string(str, next_nonwhite, last_nonwhite+1-next_nonwhite);
+}
+
+bool IsLine(const std::string &line) {
+  if (line.find('\n') != std::string::npos) return false;
+  if (line.empty()) return true;
+  if (isspace(*(line.begin()))) return false;
+  if (isspace(*(line.rbegin()))) return false;
+  std::string::const_iterator iter = line.begin(), end = line.end();
+  for (; iter != end; iter++)
+    if (!isprint(*iter)) return false;
+  return true;
+}
+
+template <class T>
+class NumberIstream{
+ public:
+  explicit NumberIstream(std::istream &i) : in_(i) {}
+
+  NumberIstream & operator >> (T &x) {
+    if (!in_.good()) return *this;
+    in_ >> x;
+    if (!in_.fail() && RemainderIsOnlySpaces()) return *this;
+    return ParseOnFail(&x);
+  }
+
+ private:
+  std::istream &in_;
+
+  bool RemainderIsOnlySpaces() {
+    if (in_.tellg() != std::istream::pos_type(-1)) {
+      std::string rem;
+      in_ >> rem;
+
+      if (rem.find_first_not_of(' ') != std::string::npos) {
+        // there is not only spaces
+        return false;
+      }
+    }
+
+    in_.clear();
+    return true;
+  }
+
+  NumberIstream & ParseOnFail(T *x) {
+    std::string str;
+    in_.clear();
+    in_.seekg(0);
+    // If the stream is broken even before trying
+    // to read from it or if there are many tokens,
+    // it's pointless to try.
+    if (!(in_ >> str) || !RemainderIsOnlySpaces()) {
+      in_.setstate(std::ios_base::failbit);
+      return *this;
+    }
+
+    std::map<std::string, T> inf_nan_map;
+    // we'll keep just uppercase values.
+    inf_nan_map["INF"] = std::numeric_limits<T>::infinity();
+    inf_nan_map["+INF"] = std::numeric_limits<T>::infinity();
+    inf_nan_map["-INF"] = - std::numeric_limits<T>::infinity();
+    inf_nan_map["INFINITY"] = std::numeric_limits<T>::infinity();
+    inf_nan_map["+INFINITY"] = std::numeric_limits<T>::infinity();
+    inf_nan_map["-INFINITY"] = - std::numeric_limits<T>::infinity();
+    inf_nan_map["NAN"] = std::numeric_limits<T>::quiet_NaN();
+    inf_nan_map["+NAN"] = std::numeric_limits<T>::quiet_NaN();
+    inf_nan_map["-NAN"] = - std::numeric_limits<T>::quiet_NaN();
+    // MSVC
+    inf_nan_map["1.#INF"] = std::numeric_limits<T>::infinity();
+    inf_nan_map["-1.#INF"] = - std::numeric_limits<T>::infinity();
+    inf_nan_map["1.#QNAN"] = std::numeric_limits<T>::quiet_NaN();
+    inf_nan_map["-1.#QNAN"] = - std::numeric_limits<T>::quiet_NaN();
+
+    std::transform(str.begin(), str.end(), str.begin(), ::toupper);
+
+    if (inf_nan_map.find(str) != inf_nan_map.end()) {
+      *x = inf_nan_map[str];
+    } else {
+      in_.setstate(std::ios_base::failbit);
+    }
+
+    return *this;
+  }
+};
+
+
+template <typename T>
+bool ConvertStringToReal(const std::string &str,
+                         T *out) {
+  std::istringstream iss(str);
+
+  NumberIstream<T> i(iss);
+
+  i >> *out;
+
+  if (iss.fail()) {
+    // Number conversion failed.
+    return false;
+  }
+
+  return true;
+}
+
+template
+bool ConvertStringToReal(const std::string &str,
+                         float *out);
+template
+bool ConvertStringToReal(const std::string &str,
+                         double *out);
+
+
+
+/*
+  This function is a helper function of StringsApproxEqual.  It should be
+  thought of as a recursive function-- it was designed that way-- but rather
+  than actually recursing (which would cause problems with stack overflow), we
+  just set the args and return to the start.
+
+  The 'decimal_places_tolerance' argument is just passed in from outside,
+  see the documentation for StringsApproxEqual in text-utils.h to see an
+  explanation.  The argument 'places_into_number' provides some information
+  about the strings 'a' and 'b' that precedes the current pointers.
+  For purposes of this comment, let's define the 'decimal' of a number
+  as the part that comes after the decimal point, e.g. in '99.123',
+  '123' would be the decimal.  If 'places_into_number' is -1, it means
+  we're not currently inside some place like that (i.e. it's not the
+  case that we're pointing to the '1' or the '2' or the '3').
+  If it's 0, then we'd be pointing to the first place after the decimal,
+  '1' in this case.  Note if one of the numbers is shorter than the
+  other, like '99.123' versus '99.1234' and 'a' points to the first '3'
+  while 'b' points to the second '4', 'places_into_number' referes to the
+  shorter of the two, i.e. it would be 2 in this example.
+
+
+ */
+bool StringsApproxEqualInternal(const char *a, const char *b,
+                                int32 decimal_places_tolerance,
+                                int32 places_into_number) {
+start:
+  char ca = *a, cb = *b;
+  if (ca == cb) {
+    if (ca == '\0') {
+      return true;
+    } else {
+      if (places_into_number >= 0) {
+        if (isdigit(ca)) {
+          places_into_number++;
+        } else {
+          places_into_number = -1;
+        }
+      } else {
+        if (ca == '.') {
+          places_into_number = 0;
+        }
+      }
+      a++;
+      b++;
+      goto start;
+    }
+  } else {
+    if (places_into_number >= decimal_places_tolerance &&
+        (isdigit(ca) || isdigit(cb))) {
+      // we're potentially willing to accept this difference between the
+      // strings.
+      if (isdigit(ca)) a++;
+      if (isdigit(cb)) b++;
+      // we'll have advanced at least one of the two strings.
+      goto start;
+    } else if (places_into_number >= 0 &&
+               ((ca == '0' && !isdigit(cb)) || (cb == '0' && !isdigit(ca)))) {
+      // this clause is designed to ensure that, for example,
+      // "0.1" would count the same as "0.100001".
+      if (ca == '0') a++;
+      else b++;
+      places_into_number++;
+      goto start;
+    } else {
+      return false;
+    }
+  }
+
+}
+
+
+bool StringsApproxEqual(const std::string &a,
+                        const std::string &b,
+                        int32 decimal_places_tolerance) {
+  return StringsApproxEqualInternal(a.c_str(), b.c_str(),
+                                    decimal_places_tolerance, -1);
+}
+
+
+bool ConfigLine::ParseLine(const std::string &line) {
+  data_.clear();
+  whole_line_ = line;
+  if (line.size() == 0) return false;   // Empty line
+  size_t pos = 0, size = line.size();
+  while (isspace(line[pos]) && pos < size) pos++;
+  if (pos == size)
+    return false;  // whitespace-only line
+  size_t first_token_start_pos = pos;
+  // first get first_token_.
+  while (!isspace(line[pos]) && pos < size) {
+    if (line[pos] == '=') {
+      // If the first block of non-whitespace looks like "foo-bar=...",
+      // then we ignore it: there is no initial token, and FirstToken()
+      // is empty.
+      pos = first_token_start_pos;
+      break;
+    }
+    pos++;
+  }
+  first_token_ = std::string(line, first_token_start_pos, pos - first_token_start_pos);
+  // first_token_ is expected to be either empty or something like
+  // "component-node", which actually is a slightly more restrictive set of
+  // strings than IsValidName() checks for this is a convenient way to check it.
+  if (!first_token_.empty() && !IsValidName(first_token_))
+    return false;
+
+  while (pos < size) {
+    if (isspace(line[pos])) {
+      pos++;
+      continue;
+    }
+
+    // OK, at this point we know that we are pointing at nonspace.
+    size_t next_equals_sign = line.find_first_of("=", pos);
+    if (next_equals_sign == pos || next_equals_sign == std::string::npos) {
+      // we're looking for something like 'key=value'.  If there is no equals sign,
+      // or it's not preceded by something, it's a parsing failure.
+      return false;
+    }
+    std::string key(line, pos, next_equals_sign - pos);
+    if (!IsValidName(key)) return false;
+
+    // handle any quotes.  we support key='blah blah' or key="foo bar".
+    // no escaping is supported.
+    if (line[next_equals_sign+1] == '\'' || line[next_equals_sign+1] == '"') {
+      char my_quote = line[next_equals_sign+1];
+      size_t next_quote = line.find_first_of(my_quote, next_equals_sign + 2);
+      if (next_quote == std::string::npos) {  // no matching quote was found.
+        KALDI_WARN << "No matching quote for " << my_quote << " in config line '"
+                   << line << "'";
+        return false;
+      } else {
+        std::string value(line, next_equals_sign + 2,
+                          next_quote - next_equals_sign - 2);
+        data_.insert(std::make_pair(key, std::make_pair(value, false)));
+        pos = next_quote + 1;
+        continue;
+      }
+    } else {
+      // we want to be able to parse something like "... input=Offset(a, -1) foo=bar":
+      // in general, config values with spaces in them, even without quoting.
+
+      size_t next_next_equals_sign = line.find_first_of("=", next_equals_sign + 1),
+          terminating_space = size;
+
+      if (next_next_equals_sign != std::string::npos) {  // found a later equals sign.
+        size_t preceding_space = line.find_last_of(" \t", next_next_equals_sign);
+        if (preceding_space != std::string::npos &&
+            preceding_space > next_equals_sign)
+          terminating_space = preceding_space;
+      }
+      while (isspace(line[terminating_space - 1]) && terminating_space > 0)
+        terminating_space--;
+
+      std::string value(line, next_equals_sign + 1,
+                        terminating_space - (next_equals_sign + 1));
+      data_.insert(std::make_pair(key, std::make_pair(value, false)));
+      pos = terminating_space;
+    }
+  }
+  return true;
+}
+
+bool ConfigLine::GetValue(const std::string &key, std::string *value) {
+  KALDI_ASSERT(value != NULL);
+  std::map<std::string, std::pair<std::string, bool> >::iterator it = data_.begin();
+  for (; it != data_.end(); ++it) {
+    if (it->first == key) {
+      *value = (it->second).first;
+      (it->second).second = true;
+      return true;
+    }
+  }
+  return false;
+}
+
+bool ConfigLine::GetValue(const std::string &key, BaseFloat *value) {
+  KALDI_ASSERT(value != NULL);
+  std::map<std::string, std::pair<std::string, bool> >::iterator it = data_.begin();
+  for (; it != data_.end(); ++it) {
+    if (it->first == key) {
+      if (!ConvertStringToReal((it->second).first, value))
+        return false;
+      (it->second).second = true;
+      return true;
+    }
+  }
+  return false;
+}
+
+bool ConfigLine::GetValue(const std::string &key, int32 *value) {
+  KALDI_ASSERT(value != NULL);
+  std::map<std::string, std::pair<std::string, bool> >::iterator it = data_.begin();
+  for (; it != data_.end(); ++it) {
+    if (it->first == key) {
+      if (!ConvertStringToInteger((it->second).first, value))
+        return false;
+      (it->second).second = true;
+      return true;
+    }
+  }
+  return false;
+}
+
+bool ConfigLine::GetValue(const std::string &key, std::vector<int32> *value) {
+  KALDI_ASSERT(value != NULL);
+  value->clear();
+  std::map<std::string, std::pair<std::string, bool> >::iterator it = data_.begin();
+  for (; it != data_.end(); ++it) {
+    if (it->first == key) {
+      if (!SplitStringToIntegers((it->second).first, ":,", true, value)) {
+        // KALDI_WARN << "Bad option " << (it->second).first;
+        return false;
+      }
+      (it->second).second = true;
+      return true;
+    }
+  }
+  return false;
+}
+
+bool ConfigLine::GetValue(const std::string &key, bool *value) {
+  KALDI_ASSERT(value != NULL);
+  std::map<std::string, std::pair<std::string, bool> >::iterator it = data_.begin();
+  for (; it != data_.end(); ++it) {
+    if (it->first == key) {
+      if ((it->second).first.size() == 0) return false;
+      switch (((it->second).first)[0]) {
+        case 'F':
+        case 'f':
+          *value = false;
+          break;
+        case 'T':
+        case 't':
+          *value = true;
+          break;
+        default:
+          return false;
+      }
+      (it->second).second = true;
+      return true;
+    }
+  }
+  return false;
+}
+
+bool ConfigLine::HasUnusedValues() const {
+  std::map<std::string, std::pair<std::string, bool> >::const_iterator it = data_.begin();
+  for (; it != data_.end(); ++it) {
+    if (!(it->second).second) return true;
+  }
+  return false;
+}
+
+std::string ConfigLine::UnusedValues() const {
+  std::string unused_str;
+  std::map<std::string, std::pair<std::string, bool> >::const_iterator it = data_.begin();
+  for (; it != data_.end(); ++it) {
+    if (!(it->second).second) {
+      if (unused_str == "")
+        unused_str = it->first + "=" + (it->second).first;
+      else
+        unused_str += " " + it->first + "=" + (it->second).first;
+    }
+  }
+  return unused_str;
+}
+
+// This is like ExpectToken but for two tokens, and it
+// will either accept token1 and then token2, or just token2.
+// This is useful in Read functions where the first token
+// may already have been consumed.
+void ExpectOneOrTwoTokens(std::istream &is, bool binary,
+                          const std::string &token1,
+                          const std::string &token2) {
+  KALDI_ASSERT(token1 != token2);
+  std::string temp;
+  ReadToken(is, binary, &temp);
+  if (temp == token1) {
+    ExpectToken(is, binary, token2);
+  } else {
+    if (temp != token2) {
+      KALDI_ERR << "Expecting token " << token1 << " or " << token2
+                << " but got " << temp;
+    }
+  }
+}
+
+
+bool IsValidName(const std::string &name) {
+  if (name.size() == 0) return false;
+  for (size_t i = 0; i < name.size(); i++) {
+    if (i == 0 && !isalpha(name[i]) && name[i] != '_')
+      return false;
+    if (!isalnum(name[i]) && name[i] != '_' && name[i] != '-' && name[i] != '.')
+      return false;
+  }
+  return true;
+}
+
+void ReadConfigLines(std::istream &is,
+                     std::vector<std::string> *lines) {
+  KALDI_ASSERT(lines != NULL);
+  std::string line;
+  while (std::getline(is, line)) {
+    if (line.size() == 0) continue;
+    size_t start = line.find_first_not_of(" \t");
+    size_t end = line.find_first_of('#');
+    if (start == std::string::npos || start == end) continue;
+    end = line.find_last_not_of(" \t", end - 1);
+    KALDI_ASSERT(end >= start);
+    lines->push_back(line.substr(start, end - start + 1));
+  }
+}
+
+void ParseConfigLines(const std::vector<std::string> &lines,
+                      std::vector<ConfigLine> *config_lines) {
+  config_lines->resize(lines.size());
+  for (size_t i = 0; i < lines.size(); i++) {
+    bool ret = (*config_lines)[i].ParseLine(lines[i]);
+    if (!ret) {
+      KALDI_ERR << "Error parsing config line: " << lines[i];
+    }
+  }
+}
+
+
+}  // end namespace kaldi
diff --git a/speechx/speechx/kaldi/util/text-utils.h b/speechx/speechx/kaldi/util/text-utils.h
new file mode 100644
index 00000000..02f4bf48
--- /dev/null
+++ b/speechx/speechx/kaldi/util/text-utils.h
@@ -0,0 +1,281 @@
+// util/text-utils.h
+
+// Copyright 2009-2011  Saarland University;  Microsoft Corporation
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//  http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_UTIL_TEXT_UTILS_H_
+#define KALDI_UTIL_TEXT_UTILS_H_
+
+#include <errno.h>
+#include <string.h>
+#include <algorithm>
+#include <limits>
+#include <map>
+#include <string>
+#include <vector>
+#include "base/kaldi-common.h"
+
+
+namespace kaldi {
+
+/// Split a string using any of the single character delimiters.
+/// If omit_empty_strings == true, the output will contain any
+/// nonempty strings after splitting on any of the
+/// characters in the delimiter.  If omit_empty_strings == false,
+/// the output will contain n+1 strings if there are n characters
+/// in the set "delim" within the input string.  In this case
+/// the empty string is split to a single empty string.
+void SplitStringToVector(const std::string &full, const char *delim,
+                         bool omit_empty_strings,
+                         std::vector<std::string> *out);
+
+/// Joins the elements of a vector of strings into a single string using
+/// "delim" as the delimiter.  If omit_empty_strings == true, any empty strings
+/// in the vector are skipped.  A vector of empty strings results in an empty
+/// string on the output.
+void JoinVectorToString(const std::vector<std::string> &vec_in,
+                        const char *delim, bool omit_empty_strings,
+                        std::string *str_out);
+
+/**
+  \brief Split a string (e.g. 1:2:3) into a vector of integers.
+
+  \param [in]  delim  String containing a list of characters, any of which
+                      is allowed as a delimiter.
+  \param [in] omit_empty_strings If true, empty strings between delimiters are
+                      allowed and will not produce an output integer; if false,
+                      instances of characters in 'delim' that are consecutive or
+                      at the start or end of the string would be an error.
+                      You'll normally want this to be true if 'delim' consists
+                      of spaces, and false otherwise.
+  \param [out] out   The output list of integers.
+*/
+template<class I>
+bool SplitStringToIntegers(const std::string &full,
+                           const char *delim,
+                           bool omit_empty_strings,  // typically false [but
+                                                     // should probably be true
+                                                     // if "delim" is spaces].
+                           std::vector<I> *out) {
+  KALDI_ASSERT(out != NULL);
+  KALDI_ASSERT_IS_INTEGER_TYPE(I);
+  if (*(full.c_str()) == '\0') {
+    out->clear();
+    return true;
+  }
+  std::vector<std::string> split;
+  SplitStringToVector(full, delim, omit_empty_strings, &split);
+  out->resize(split.size());
+  for (size_t i = 0; i < split.size(); i++) {
+    const char *this_str = split[i].c_str();
+    char *end = NULL;
+    int64 j = 0;
+    j = KALDI_STRTOLL(this_str, &end);
+    if (end == this_str || *end != '\0') {
+      out->clear();
+      return false;
+    } else {
+      I jI = static_cast<I>(j);
+      if (static_cast<int64>(jI) != j) {
+        // output type cannot fit this integer.
+        out->clear();
+        return false;
+      }
+      (*out)[i] = jI;
+    }
+  }
+  return true;
+}
+
+// This is defined for F = float and double.
+template<class F>
+bool SplitStringToFloats(const std::string &full,
+                         const char *delim,
+                         bool omit_empty_strings,  // typically false
+                         std::vector<F> *out);
+
+
+/// Converts a string into an integer via strtoll and returns false if there was
+/// any kind of problem (i.e. the string was not an integer or contained extra
+/// non-whitespace junk, or the integer was too large to fit into the type it is
+/// being converted into).  Only sets *out if everything was OK and it returns
+/// true.
+template +bool ConvertStringToInteger(const std::string &str, + Int *out) { + KALDI_ASSERT_IS_INTEGER_TYPE(Int); + const char *this_str = str.c_str(); + char *end = NULL; + errno = 0; + int64 i = KALDI_STRTOLL(this_str, &end); + if (end != this_str) + while (isspace(*end)) end++; + if (end == this_str || *end != '\0' || errno != 0) + return false; + Int iInt = static_cast(i); + if (static_cast(iInt) != i || + (i < 0 && !std::numeric_limits::is_signed)) { + return false; + } + *out = iInt; + return true; +} + + +/// ConvertStringToReal converts a string into either float or double +/// and returns false if there was any kind of problem (i.e. the string +/// was not a floating point number or contained extra non-whitespace junk). +/// Be careful- this function will successfully read inf's or nan's. +template +bool ConvertStringToReal(const std::string &str, + T *out); + +/// Removes the beginning and trailing whitespaces from a string +void Trim(std::string *str); + + +/// Removes leading and trailing white space from the string, then splits on the +/// first section of whitespace found (if present), putting the part before the +/// whitespace in "first" and the rest in "rest". If there is no such space, +/// everything that remains after removing leading and trailing whitespace goes +/// in "first". +void SplitStringOnFirstSpace(const std::string &line, + std::string *first, + std::string *rest); + + +/// Returns true if "token" is nonempty, and all characters are +/// printable and whitespace-free. +bool IsToken(const std::string &token); + + +/// Returns true if "line" is free of \n characters and unprintable +/// characters, and does not contain leading or trailing whitespace. +bool IsLine(const std::string &line); + + + +/** + This function returns true when two text strings are approximately equal, and + false when they are not. 
The definition of 'equal' is normal string + equality, except that two substrings like "0.31134" and "0.311341" would be + considered equal. 'decimal_places_tolerance' controls how many digits after + the '.' have to match up. + E.g. StringsApproxEqual("hello 0.23 there", "hello 0.24 there", 2) would + return false because there is a difference in the 2nd decimal, but with + an argument of 1 it would return true. + */ +bool StringsApproxEqual(const std::string &a, + const std::string &b, + int32 decimal_places_check = 2); + +/** + This class is responsible for parsing input like + hi-there xx=yyy a=b c empty= f-oo=Append(bar, sss) ba_z=123 bing='a b c' baz="a b c d='a b' e" + and giving you access to the fields, in this case + + FirstToken() == "hi-there", and key->value pairs: + + xx->yyy, a->"b c", empty->"", f-oo->"Append(bar, sss)", ba_z->"123", + bing->"a b c", baz->"a b c d='a b' e" + + The first token is optional, if the line started with a key-value pair then + FirstValue() will be empty. + + Note: it can parse value fields with space inside them only if they are free of the '=' + character. If values are going to contain the '=' character, you need to quote them + with either single or double quotes. + + Key values may contain -_a-zA-Z0-9, but must begin with a-zA-Z_. + */ +class ConfigLine { + public: + // Tries to parse the line as a config-file line. Returns false + // if it could not for some reason, e.g. parsing failure. In most cases + // prints no warnings; the user should do this. Does not expect comments. + bool ParseLine(const std::string &line); + + // the GetValue functions are overloaded for various types. They return true + // if the key exists with value that can be converted to that type, and false + // otherwise. They also mark the key-value pair as having been read. It is + // not an error to read values twice. 
+ bool GetValue(const std::string &key, std::string *value); + bool GetValue(const std::string &key, BaseFloat *value); + bool GetValue(const std::string &key, int32 *value); + // Values may be separated by ":" or by ",". + bool GetValue(const std::string &key, std::vector *value); + bool GetValue(const std::string &key, bool *value); + + bool HasUnusedValues() const; + /// returns e.g. foo=bar xxx=yyy if foo and xxx were not consumed by one + /// of the GetValue() functions. + std::string UnusedValues() const; + + const std::string &FirstToken() const { return first_token_; } + + const std::string WholeLine() { return whole_line_; } + // use default assignment operator and copy constructor. + private: + std::string whole_line_; + // the first token of the line, e.g. if line is + // foo-bar baz=bing + // then first_token_ would be "foo-bar". + std::string first_token_; + + // data_ maps from key to (value, is-this-value-consumed?). + std::map > data_; + +}; + +/// This function is like ExpectToken but for two tokens, and it will either +/// accept token1 and then token2, or just token2. This is useful in Read +/// functions where the first token may already have been consumed. +void ExpectOneOrTwoTokens(std::istream &is, bool binary, + const std::string &token1, + const std::string &token2); + + +/** + This function reads in a config file and *appends* its contents to a vector of + lines; it is responsible for removing comments (anything after '#') and + stripping out any lines that contain only whitespace after comment removal. + */ +void ReadConfigLines(std::istream &is, + std::vector *lines); + + +/** + This function converts config-lines from a simple sequence of strings + as output by ReadConfigLines(), into a sequence of first-tokens and + name-value pairs. The general format is: + "command-type bar=baz xx=yyy" + etc., although there are subtleties as to what exactly is allowed, see + documentation for class ConfigLine for details. 
+ This function will die if there was a parsing failure. + */ +void ParseConfigLines(const std::vector &lines, + std::vector *config_lines); + + +/// Returns true if 'name' would be a valid name for a component or node in a +/// nnet3Nnet. This is a nonempty string beginning with A-Za-z_, and containing only +/// '-', '_', '.', A-Z, a-z, or 0-9. +bool IsValidName(const std::string &name); + +} // namespace kaldi + +#endif // KALDI_UTIL_TEXT_UTILS_H_ diff --git a/speechx/speechx/third_party/README.md b/speechx/speechx/third_party/README.md new file mode 100644 index 00000000..2d620335 --- /dev/null +++ b/speechx/speechx/third_party/README.md @@ -0,0 +1,4 @@ +# third party + +Those libs copied and developed from third pary opensource software projects. +For all of these things, the official websites are the best place to go. -- GitLab