提交 c462ab1a 编写于 作者: Y Yibing Liu

Refine infer_by_ckpt: code clean & move out cer scoring

上级 cf064d6c
decode_to_path=./decoding_result.txt
export CUDA_VISIBLE_DEVICES=2,3,4,5
python -u ../../infer_by_ckpt.py --batch_size 96 \
--checkpoint checkpoints/deep_asr.pass_20.checkpoint \
--infer_feature_lst data/test_feature.lst \
--infer_label_lst data/test_label.lst \
--mean_var data/aishell/global_mean_var \
--frame_dim 80 \
--class_num 3040 \
--target_trans data/text.test \
--num_threads 24 \
--decode_to_path $decode_to_path \
--trans_model mapped_decoder_data/exp/tri5a/final.mdl \
--log_prior mapped_decoder_data/logprior \
--vocabulary mapped_decoder_data/exp/tri5a/graph/words.txt \
......
ref_txt=data/text.test
hyp_txt=decoding_result.txt
python ../../score_error_rate.py --error_rate_type cer --ref $ref_txt --hyp $hyp_txt
......@@ -14,10 +14,9 @@ import data_utils.augmentor.trans_add_delta as trans_add_delta
import data_utils.augmentor.trans_splice as trans_splice
import data_utils.augmentor.trans_delay as trans_delay
import data_utils.async_data_reader as reader
from decoder.post_latgen_faster_mapped import Decoder
from data_utils.util import lodtensor_to_ndarray
from data_utils.util import lodtensor_to_ndarray, split_infer_result
from model_utils.model import stacked_lstmp_model
from data_utils.util import split_infer_result
from decoder.post_latgen_faster_mapped import Decoder
from tools.error_rate import char_errors
......@@ -64,11 +63,6 @@ def parse_args():
type=int,
default=10,
help='The number of threads for decoding. (default: %(default)d)')
parser.add_argument(
'--learning_rate',
type=float,
default=0.00016,
help='Learning rate used to train. (default: %(default)f)')
parser.add_argument(
'--device',
type=str,
......@@ -80,7 +74,7 @@ def parse_args():
parser.add_argument(
'--mean_var',
type=str,
default='data/global_mean_var_search26kHr',
default='data/global_mean_var',
help="The path for feature's global mean and variance. "
"(default: %(default)s)")
parser.add_argument(
......@@ -88,16 +82,6 @@ def parse_args():
type=str,
default='data/infer_feature.lst',
help='The feature list path for inference. (default: %(default)s)')
parser.add_argument(
'--infer_label_lst',
type=str,
default='data/infer_label.lst',
help='The label list path for inference. (default: %(default)s)')
parser.add_argument(
'--ref_txt',
type=str,
default='data/text.test',
help='The reference text for decoding. (default: %(default)s)')
parser.add_argument(
'--checkpoint',
type=str,
......@@ -128,16 +112,17 @@ def parse_args():
type=float,
default=0.2,
help="Scaling factor for acoustic likelihoods. (default: %(default)f)")
parser.add_argument(
'--target_trans',
type=str,
default="./decoder/target_trans.txt",
help="The path to target transcription. (default: %(default)s)")
parser.add_argument(
'--post_matrix_path',
type=str,
default=None,
help="The path to output post prob matrix. (default: %(default)s)")
parser.add_argument(
'--decode_to_path',
type=str,
default='./decoding_result.txt',
required=True,
help="The path to output the decoding result. (default: %(default)s)")
args = parser.parse_args()
return args
......@@ -149,26 +134,47 @@ def print_arguments(args):
print('------------------------------------------------')
def get_trg_trans(args):
trans_dict = {}
with open(args.target_trans) as trg_trans:
line = trg_trans.readline()
while line:
items = line.strip().split()
key = items[0]
trans_dict[key] = ''.join(items[1:])
line = trg_trans.readline()
return trans_dict
class PostMatrixWriter:
""" The writer for outputing the post probability matrix
"""
def __init__(self, to_path):
self._to_path = to_path
with open(self._to_path, "w") as post_matrix:
post_matrix.seek(0)
post_matrix.truncate()
def write(self, keys, probs):
with open(self._to_path, "a") as post_matrix:
if isinstance(keys, str):
keys, probs = [keys], [probs]
for key, prob in zip(keys, probs):
post_matrix.write(key + " [\n")
for i in range(prob.shape[0]):
for j in range(prob.shape[1]):
post_matrix.write(str(prob[i][j]) + " ")
post_matrix.write("\n")
post_matrix.write("]\n")
class DecodingResultWriter:
""" The writer for writing out decoding results
"""
def out_post_matrix(key, prob):
with open(args.post_matrix_path, "a") as post_matrix:
post_matrix.write(key + " [\n")
for i in range(prob.shape[0]):
for j in range(prob.shape[1]):
post_matrix.write(str(prob[i][j]) + " ")
post_matrix.write("\n")
post_matrix.write("]\n")
def __init__(self, to_path):
self._to_path = to_path
with open(self._to_path, "w") as decoding_result:
decoding_result.seek(0)
decoding_result.truncate()
def write(self, results):
with open(self._to_path, "a") as decoding_result:
if isinstance(results, str):
decoding_result.write(results.encode("utf8") + "\n")
else:
for result in results:
decoding_result.write(result.encode("utf8") + "\n")
def infer_from_ckpt(args):
......@@ -187,9 +193,10 @@ def infer_from_ckpt(args):
infer_program = fluid.default_main_program().clone()
# optimizer, placeholder
optimizer = fluid.optimizer.Adam(
learning_rate=fluid.layers.exponential_decay(
learning_rate=args.learning_rate,
learning_rate=0.0001,
decay_steps=1879,
decay_rate=1 / 1.2,
staircase=True))
......@@ -199,7 +206,6 @@ def infer_from_ckpt(args):
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
#trg_trans = get_trg_trans(args)
# load checkpoint.
fluid.io.load_persistables(exe, args.checkpoint)
......@@ -218,13 +224,13 @@ def infer_from_ckpt(args):
# infer data reader
infer_data_reader = reader.AsyncDataReader(
args.infer_feature_lst,
args.infer_label_lst,
drop_frame_len=-1,
split_sentence_threshold=-1)
args.infer_feature_lst, drop_frame_len=-1, split_sentence_threshold=-1)
infer_data_reader.set_transformers(ltrans)
infer_costs, infer_accs = [], []
total_edit_dist, total_ref_len = 0.0, 0
decoding_result_writer = DecodingResultWriter(args.decode_to_path)
post_matrix_writer = None if args.post_matrix_path is None \
else PostMatrixWriter(args.post_matrix_path)
for batch_id, batch_data in enumerate(
infer_data_reader.batch_iterator(args.batch_size,
args.minimum_batch_size)):
......@@ -242,31 +248,17 @@ def infer_from_ckpt(args):
"label": label_t},
fetch_list=[prediction, avg_cost, accuracy],
return_numpy=False)
infer_costs.append(lodtensor_to_ndarray(results[1])[0])
infer_accs.append(lodtensor_to_ndarray(results[2])[0])
probs, lod = lodtensor_to_ndarray(results[0])
infer_batch = split_infer_result(probs, lod)
print("Decoding batch %d ..." % batch_id)
decoded = decoder.decode_batch(name_lst, infer_batch, args.num_threads)
for res in decoded:
print(res.encode("utf8"))
decoding_result_writer.write(decoded)
if args.post_matrix_path is not None:
for index, sample in enumerate(infer_batch):
key = name_lst[index]
out_post_matrix(key, sample)
'''
hyp = decoder.decode(key, sample)
edit_dist, ref_len = char_errors(ref.decode("utf8"), hyp)
total_edit_dist += edit_dist
total_ref_len += ref_len
print(key + "|Ref:", ref)
print(key + "|Hyp:", hyp.encode("utf8"))
print("Instance CER: ", edit_dist / ref_len)
'''
#print("batch: ", batch_id)
#print("Total CER = %f" % (total_edit_dist / total_ref_len))
post_matrix_writer.write(name_lst, infer_batch)
if __name__ == '__main__':
......
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from tools.error_rate import char_errors, word_errors
def parse_args():
parser = argparse.ArgumentParser(
"Score word/character error rate (WER/CER) "
"for decoding result.")
parser.add_argument(
'--error_rate_type',
type=str,
default='cer',
choices=['cer', 'wer'],
help="Error rate type. (default: %(default)s)")
parser.add_argument(
'--ref', type=str, required=True, help="The ground truth text.")
parser.add_argument(
'--hyp', type=str, required=True, help="The decoding result.")
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
ref_dict = {}
sum_errors, sum_ref_len = 0.0, 0
sent_cnt, not_in_ref_cnt = 0, 0
with open(args.ref, "r") as ref_txt:
line = ref_txt.readline()
while line:
del_pos = line.find(" ")
key, sent = line[0:del_pos], line[del_pos + 1:-1].strip()
ref_dict[key] = sent
line = ref_txt.readline()
with open(args.hyp, "r") as hyp_txt:
line = hyp_txt.readline()
while line:
del_pos = line.find(" ")
key, sent = line[0:del_pos], line[del_pos + 1:-1].strip()
sent_cnt += 1
line = hyp_txt.readline()
if key not in ref_dict:
not_in_ref_cnt += 1
continue
if args.error_rate_type == 'cer':
errors, ref_len = char_errors(
ref_dict[key].decode("utf8"),
sent.decode("utf8"),
remove_space=True)
else:
errors, ref_len = word_errors(ref_dict[key].decode("utf8"),
sent.decode("utf8"))
sum_errors += errors
sum_ref_len += ref_len
print("Error rate[%s] = %f (%d/%d)," %
(args.error_rate_type, sum_errors / sum_ref_len, int(sum_errors),
sum_ref_len))
print("total %d sentences in hyp, %d not presented in ref." %
(sent_cnt, not_in_ref_cnt))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册