提交 e1fc57de 编写于 作者: H Hui Zhang

add math and rename ds2 nnet

上级 75c57880
......@@ -14,19 +14,24 @@
#pragma once
#include <algorithm>
#include <condition_variable>
#include <cstring>
#include <deque>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <istream>
#include <map>
#include <memory>
#include <mutex>
#include <numeric>
#include <ostream>
#include <queue>
#include <set>
#include <sstream>
#include <stack>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_map>
......
......@@ -20,7 +20,7 @@
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
#include "nnet/ds2_nnet.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
......
......@@ -20,7 +20,7 @@
#include "decoder/ctc_tlg_decoder.h"
#include "frontend/audio/feature_pipeline.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
#include "nnet/ds2_nnet.h"
namespace ppspeech {
......
......@@ -20,7 +20,7 @@
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
#include "nnet/ds2_nnet.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(result_wspecifier, "", "test result wspecifier");
......
......@@ -2,13 +2,11 @@ project(nnet)
add_library(nnet STATIC
decodable.cc
paddle_nnet.cc
ds2_nnet.cc
)
target_link_libraries(nnet absl::strings)
set(bin_name nnet_forward_main)
set(bin_name ds2_nnet_main)
add_executable(${bin_name} ${CMAKE_CURRENT_SOURCE_DIR}/${bin_name}.cc)
target_include_directories(${bin_name} PRIVATE ${SPEECHX_ROOT} ${SPEECHX_ROOT}/kaldi)
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet ${DEPS})
target_link_libraries(${bin_name} utils kaldi-util kaldi-matrix gflags glog nnet ${DEPS})
\ No newline at end of file
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/paddle_nnet.h"
#include "nnet/ds2_nnet.h"
#include "absl/strings/str_split.h"
namespace ppspeech {
......
......@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "nnet/ds2_nnet.h"
#include "base/flags.h"
#include "base/log.h"
#include "frontend/audio/assembler.h"
#include "frontend/audio/data_cache.h"
#include "kaldi/util/table-types.h"
#include "nnet/decodable.h"
#include "nnet/paddle_nnet.h"
DEFINE_string(feature_rspecifier, "", "test feature rspecifier");
DEFINE_string(nnet_prob_wspecifier, "", "nnet porb wspecifier");
......
// Copyright (c) 2021 Mobvoi Inc (Zhendong Peng)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
......@@ -17,10 +18,10 @@
#include "base/common.h"
#include <cmath>
#include <algorithm>
#include <utility>
#include <cmath>
#include <queue>
#include <utility>
namespace ppspeech {
......@@ -36,28 +37,36 @@ float LogSumExp(float x, float y) {
// greater compare for smallest priority_queue
template <typename T>
struct ValGreaterComp {
bool operator()(const std::pair<T, int32_t>& lhs, const std::pair<T, int32_>& rhs) const {
return lhs.first > rhs.first || (lhs.first == rhs.first && lhs.second < rhs.second);
bool operator()(const std::pair<T, int32_t>& lhs,
const std::pair<T, int32_>& rhs) const {
return lhs.first > rhs.first ||
(lhs.first == rhs.first && lhs.second < rhs.second);
}
}
template<typename T>
void TopK(const std::vector<T>& data, int32_t k, std::vector<T>* values, std::vector<int>* indices) {
int n = data.size();
int min_k_n = std::min(k, n);
template <typename T>
void TopK(const std::vector<T>& data,
int32_t k,
std::vector<T>* values,
std::vector<int>* indices) {
int n = data.size();
int min_k_n = std::min(k, n);
// smallest heap, (val, idx)
std::vector<std::pair<T, int32_t>> smallest_heap;
for (int i = 0; i < min_k_n; i++){
std::vector<std::pair<T, int32_t>> smallest_heap;
for (int i = 0; i < min_k_n; i++) {
smallest_heap.emplace_back(data[i], i);
}
// smallest priority_queue
std::priority_queue<std::pair<T, int32_t>, std::vector<std::pair<T, int32_t>>, ValGreaterComp<T>> pq(ValGreaterComp<T>(), std::move(smallest_heap));
std::priority_queue<std::pair<T, int32_t>,
std::vector<std::pair<T, int32_t>>,
ValGreaterComp<T>>
pq(ValGreaterComp<T>(), std::move(smallest_heap));
// top k
for (int i = k ; i < n; i++){
if (pq.top().first < data[i]){
for (int i = k; i < n; i++) {
if (pq.top().first < data[i]) {
pq.pop();
pq.emplace_back(data[i], i);
}
......@@ -68,7 +77,7 @@ void TopK(const std::vector<T>& data, int32_t k, std::vector<T>* values, std::ve
// from largest to samllest
int cur = values->size() - 1;
while(!pq.empty()){
while (!pq.empty()) {
const auto& item = pq.top();
pq.pop();
......
......@@ -14,15 +14,18 @@
#pragma once
#include <vector>
#include <cstdint>
#include <vector>
namespace ppspeech {
// Sum in log scale
float LogSumExp(float x, float y);
template<typename T>
void TopK(const std::vector<T>& data, int32_t k, std::vector<T>* values, std::vector<int>* indices);
template <typename T>
void TopK(const std::vector<T>& data,
int32_t k,
std::vector<T>* values,
std::vector<int>* indices);
} // namespace ppspeech
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册