Commit 902c35bd authored by Yibing Liu

append some changes

@@ -18,7 +18,7 @@ std::string ctc_greedy_decoder(
     const std::vector<std::string> &vocabulary) {
   // dimension check
   size_t num_time_steps = probs_seq.size();
-  for (size_t i = 0; i < num_time_steps; i++) {
+  for (size_t i = 0; i < num_time_steps; ++i) {
     VALID_CHECK_EQ(probs_seq[i].size(),
                    vocabulary.size() + 1,
                    "The shape of probs_seq does not match with "
@@ -28,7 +28,7 @@ std::string ctc_greedy_decoder(
   size_t blank_id = vocabulary.size();
   std::vector<size_t> max_idx_vec;
-  for (size_t i = 0; i < num_time_steps; i++) {
+  for (size_t i = 0; i < num_time_steps; ++i) {
     double max_prob = 0.0;
     size_t max_idx = 0;
     for (size_t j = 0; j < probs_seq[i].size(); j++) {
@@ -41,14 +41,14 @@ std::string ctc_greedy_decoder(
   }
   std::vector<size_t> idx_vec;
-  for (size_t i = 0; i < max_idx_vec.size(); i++) {
+  for (size_t i = 0; i < max_idx_vec.size(); ++i) {
     if ((i == 0) || ((i > 0) && max_idx_vec[i] != max_idx_vec[i - 1])) {
       idx_vec.push_back(max_idx_vec[i]);
     }
   }
   std::string best_path_result;
-  for (size_t i = 0; i < idx_vec.size(); i++) {
+  for (size_t i = 0; i < idx_vec.size(); ++i) {
     if (idx_vec[i] != blank_id) {
       best_path_result += vocabulary[idx_vec[i]];
     }
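For reference, the three steps touched by the hunks above (per-frame argmax, collapsing consecutive repeats, dropping the blank label) can be restated in a few lines of NumPy. This is an illustrative sketch of the same logic, not code from the repository; the toy probabilities and vocabulary are made up:

import numpy as np

def greedy_decode(probs_seq, vocabulary):
    """Mirror of the C++ ctc_greedy_decoder logic shown in the diff."""
    blank_id = len(vocabulary)                    # blank is the last class
    # 1. best class per time step
    max_idx_vec = np.argmax(probs_seq, axis=1)
    # 2. collapse consecutive repeats
    idx_vec = [int(idx) for i, idx in enumerate(max_idx_vec)
               if i == 0 or idx != max_idx_vec[i - 1]]
    # 3. drop blanks and map the rest to characters
    return "".join(vocabulary[idx] for idx in idx_vec if idx != blank_id)

# toy example: vocabulary of two characters plus the implicit blank
probs = np.array([[0.6, 0.3, 0.1],   # 'a'
                  [0.6, 0.3, 0.1],   # 'a' (repeat, collapsed)
                  [0.1, 0.2, 0.7],   # blank
                  [0.2, 0.7, 0.1]])  # 'b'
print(greedy_decode(probs, ["a", "b"]))   # -> "ab"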
@@ -65,7 +65,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
     Scorer *ext_scorer) {
   // dimension check
   size_t num_time_steps = probs_seq.size();
-  for (size_t i = 0; i < num_time_steps; i++) {
+  for (size_t i = 0; i < num_time_steps; ++i) {
     VALID_CHECK_EQ(probs_seq[i].size(),
                    vocabulary.size() + 1,
                    "The shape of probs_seq does not match with "
@@ -111,7 +111,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
   for (size_t time_step = 0; time_step < num_time_steps; time_step++) {
     std::vector<double> prob = probs_seq[time_step];
     std::vector<std::pair<int, double>> prob_idx;
-    for (size_t i = 0; i < prob.size(); i++) {
+    for (size_t i = 0; i < prob.size(); ++i) {
      prob_idx.push_back(std::pair<int, double>(i, prob[i]));
    }
@@ -134,7 +134,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
     if (cutoff_prob < 1.0) {
       double cum_prob = 0.0;
       cutoff_len = 0;
-      for (size_t i = 0; i < prob_idx.size(); i++) {
+      for (size_t i = 0; i < prob_idx.size(); ++i) {
         cum_prob += prob_idx[i].second;
         cutoff_len += 1;
         if (cum_prob >= cutoff_prob) break;
@@ -145,7 +145,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
           prob_idx.begin(), prob_idx.begin() + cutoff_len);
     }
     std::vector<std::pair<size_t, float>> log_prob_idx;
-    for (size_t i = 0; i < cutoff_len; i++) {
+    for (size_t i = 0; i < cutoff_len; ++i) {
       log_prob_idx.push_back(std::pair<int, float>(
           prob_idx[i].first, log(prob_idx[i].second + NUM_FLT_MIN)));
     }
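The pruning step modified above keeps only the most probable classes per frame: sort by probability, accumulate until cutoff_prob is reached, truncate to cutoff_len entries, then switch to log probabilities. A minimal sketch of that idea (not the repository's code; eps stands in for the C++ NUM_FLT_MIN constant):

import math

def prune_log_probs(prob, cutoff_prob=0.99, eps=1e-45):
    """Keep the smallest set of classes whose cumulative mass reaches cutoff_prob."""
    prob_idx = sorted(enumerate(prob), key=lambda kv: kv[1], reverse=True)
    cutoff_len = len(prob_idx)
    if cutoff_prob < 1.0:
        cum_prob, cutoff_len = 0.0, 0
        for _, p in prob_idx:
            cum_prob += p
            cutoff_len += 1
            if cum_prob >= cutoff_prob:
                break
        prob_idx = prob_idx[:cutoff_len]
    return [(idx, math.log(p + eps)) for idx, p in prob_idx]

print(prune_log_probs([0.5, 0.3, 0.15, 0.05], cutoff_prob=0.9))
# keeps classes 0, 1 and 2, since 0.5 + 0.3 + 0.15 >= 0.9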
@@ -155,7 +155,7 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
       auto c = log_prob_idx[index].first;
       float log_prob_c = log_prob_idx[index].second;
-      for (size_t i = 0; i < prefixes.size() && i < beam_size; i++) {
+      for (size_t i = 0; i < prefixes.size() && i < beam_size; ++i) {
         auto prefix = prefixes[i];
         if (full_beam && log_prob_c + prefix->score < min_cutoff) {
@@ -222,14 +222,14 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
                       prefixes.end(),
                       prefix_compare);
-      for (size_t i = beam_size; i < prefixes.size(); i++) {
+      for (size_t i = beam_size; i < prefixes.size(); ++i) {
         prefixes[i]->remove();
       }
     }
   }  // end of loop over time

   // compute approximate ctc score as the return score
-  for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) {
+  for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
     double approx_ctc = prefixes[i]->score;
     if (ext_scorer != nullptr) {
@@ -249,14 +249,14 @@ std::vector<std::pair<double, std::string>> ctc_beam_search_decoder(
   // allow for the post processing
   std::vector<PathTrie *> space_prefixes;
   if (space_prefixes.empty()) {
-    for (size_t i = 0; i < beam_size && i < prefixes.size(); i++) {
+    for (size_t i = 0; i < beam_size && i < prefixes.size(); ++i) {
       space_prefixes.push_back(prefixes[i]);
     }
   }
   std::sort(space_prefixes.begin(), space_prefixes.end(), prefix_compare);
   std::vector<std::pair<double, std::string>> output_vecs;
-  for (size_t i = 0; i < beam_size && i < space_prefixes.size(); i++) {
+  for (size_t i = 0; i < beam_size && i < space_prefixes.size(); ++i) {
     std::vector<int> output;
     space_prefixes[i]->get_path_vec(output);
     // convert index to string
@@ -301,7 +301,7 @@ ctc_beam_search_decoder_batch(
   // enqueue the tasks of decoding
   std::vector<std::future<std::vector<std::pair<double, std::string>>>> res;
-  for (size_t i = 0; i < batch_size; i++) {
+  for (size_t i = 0; i < batch_size; ++i) {
     res.emplace_back(pool.enqueue(ctc_beam_search_decoder,
                                   probs_split[i],
                                   beam_size,
@@ -313,7 +313,7 @@ ctc_beam_search_decoder_batch(
   // get decoding results
   std::vector<std::vector<std::pair<double, std::string>>> batch_results;
-  for (size_t i = 0; i < batch_size; i++) {
+  for (size_t i = 0; i < batch_size; ++i) {
     batch_results.emplace_back(res[i].get());
   }
   return batch_results;
......
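The batch entry point above fans one decoding task per utterance out to a thread pool and then collects the futures in order. The same enqueue-then-get pattern can be sketched with Python's standard concurrent.futures; the real code uses a C++ ThreadPool, and the function and parameter names below are illustrative only:

from concurrent.futures import ThreadPoolExecutor

def decode_batch(probs_split, beam_size, vocabulary, decoder, num_workers=4):
    """Enqueue one decoding task per utterance, then gather results in order."""
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        # enqueue the tasks of decoding (pool.enqueue in the C++ version)
        futures = [pool.submit(decoder, probs_seq, beam_size, vocabulary)
                   for probs_seq in probs_split]
        # get decoding results (res[i].get() in the C++ version)
        return [f.result() for f in futures]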
#! /usr/bin/bash
source ../../utils/utility.sh
URL='http://cloud.dlnel.org/filepub/?uuid=6c83b9d8-3255-4adf-9726-0fe0be3d0274'
MD5=28521a58552885a81cf92a1e9b133a71
TARGET=./aishell_model.tar.gz
echo "Download Aishell model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
echo "Fail to download Aishell model!"
exit 1
fi
tar -zxvf $TARGET
exit 0
@@ -2,9 +2,8 @@
 source ../../utils/utility.sh
-# TODO: add urls
-URL='to-be-added'
-MD5=5b4af224b26c1dc4dd972b7d32f2f52a
+URL='http://cloud.dlnel.org/filepub/?uuid=17404caf-cf19-492f-9707-1fad07c19aae'
+MD5=ea5024a457a91179472f6dfee60e053d
 TARGET=./librispeech_model.tar.gz
......
#! /usr/bin/bash
source ../../utils/utility.sh
URL=http://cloud.dlnel.org/filepub/?uuid=d21861e4-4ed6-45bb-ad8e-ae417a43195e
MD5="29e02312deb2e59b3c8686c7966d4fe3"
TARGET=./zh_giga.no_cna_cmn.prune01244.klm
echo "Download language model ..."
download $URL $MD5 $TARGET
if [ $? -ne 0 ]; then
echo "Fail to download the language model!"
exit 1
fi
exit 0
@@ -2,4 +2,3 @@ scipy==0.13.1
 resampy==0.1.5
 SoundFile==0.9.0.post1
 python_speech_features
-https://github.com/luotao1/kenlm/archive/master.zip
@@ -11,10 +11,9 @@ download() {
     fi
   fi
-  wget -c $URL -P `dirname "$TARGET"`
+  wget -c $URL -O "$TARGET"
   md5_result=`md5sum $TARGET | awk -F[' '] '{print $1}'`
   if [ ! $MD5 == $md5_result ]; then
-    echo "Fail to download the language model!"
     return 1
   fi
 }
@@ -35,6 +35,8 @@ class ExternalMemory(object):
                        sequence layer has sequence length indicating the number
                        of memory slots, and size as memory slot size.
     :type boot_layer: LayerOutput
+    :param initial_weight: Initializer for addressing weights.
+    :type initial_weight: LayerOutput
     :param readonly: If true, the memory is read-only, and write function cannot
                      be called. Default is false.
     :type readonly: bool
@@ -49,6 +51,7 @@ class ExternalMemory(object):
                  name,
                  mem_slot_size,
                  boot_layer,
+                 initial_weight,
                  readonly=False,
                  enable_interpolation=True):
         self.name = name
@@ -57,11 +60,7 @@ class ExternalMemory(object):
         self.enable_interpolation = enable_interpolation
         self.external_memory = paddle.layer.memory(
             name=self.name, size=self.mem_slot_size, boot_layer=boot_layer)
-        # prepare a constant (zero) intializer for addressing weights
-        self.zero_addressing_init = paddle.layer.slope_intercept(
-            input=paddle.layer.fc(input=boot_layer, size=1),
-            slope=0.0,
-            intercept=0.0)
+        self.initial_weight = initial_weight
         # set memory to constant when readonly=True
         if self.readonly:
             self.updated_external_memory = paddle.layer.mixed(
...@@ -111,7 +110,7 @@ class ExternalMemory(object): ...@@ -111,7 +110,7 @@ class ExternalMemory(object):
last_addressing_weight = paddle.layer.memory( last_addressing_weight = paddle.layer.memory(
name=self.name + "_addressing_weight_" + head_name, name=self.name + "_addressing_weight_" + head_name,
size=1, size=1,
boot_layer=self.zero_addressing_init) boot_layer=self.initial_weight)
interpolated_weight = paddle.layer.interpolation( interpolated_weight = paddle.layer.interpolation(
name=self.name + "_addressing_weight_" + head_name, name=self.name + "_addressing_weight_" + head_name,
input=[addressing_weight, addressing_weight], input=[addressing_weight, addressing_weight],
......
@@ -125,7 +125,15 @@ def memory_enhanced_decoder(input, target, initial_state, source_context, size,
             bounded_memory_perturbation
         ],
         act=paddle.activation.Linear())
+    bounded_memory_weight_init = paddle.layer.slope_intercept(
+        input=paddle.layer.fc(input=bounded_memory_init, size=1),
+        slope=0.0,
+        intercept=0.0)
     unbounded_memory_init = source_context
+    unbounded_memory_weight_init = paddle.layer.slope_intercept(
+        input=paddle.layer.fc(input=unbounded_memory_init, size=1),
+        slope=0.0,
+        intercept=0.0)

     # prepare step function for recurrent group
     def recurrent_decoder_step(cur_embedding):
@@ -136,12 +144,14 @@ def memory_enhanced_decoder(input, target, initial_state, source_context, size,
             name="bounded_memory",
             mem_slot_size=size,
             boot_layer=bounded_memory_init,
+            initial_weight=bounded_memory_weight_init,
             readonly=False,
             enable_interpolation=True)
         unbounded_memory = ExternalMemory(
             name="unbounded_memory",
             mem_slot_size=size * 2,
             boot_layer=unbounded_memory_init,
+            initial_weight=unbounded_memory_weight_init,
             readonly=True,
             enable_interpolation=False)
         # write bounded memory
......
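The slope_intercept construction moved to the caller above may look opaque: with slope=0.0 and intercept=0.0 it simply produces an all-zero, size-1 output whose sequence structure follows the boot layer, which is what the addressing-weight memory needs as its initial value. A rough NumPy analogy of that arithmetic (shapes here are illustrative assumptions, not taken from the model):

import numpy as np

# hypothetical boot layer output: 5 memory slots, slot size 128
boot_layer_out = np.random.rand(5, 128)

# fc(input=boot_layer, size=1): project each slot down to a single scalar
fc_out = boot_layer_out @ np.random.rand(128, 1)    # shape (5, 1)

# slope_intercept with slope=0.0, intercept=0.0: multiply by 0 and add 0,
# i.e. one zero per memory slot, used to boot the addressing weights
initial_weight = 0.0 * fc_out + 0.0                 # shape (5, 1), all zeros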