From c7d9b11529561382039c70a57777a8f0e5e024d6 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 21 Apr 2022 11:27:26 +0000 Subject: [PATCH] format --- .flake8 | 2 + paddlespeech/cli/asr/infer.py | 6 +- paddlespeech/s2t/models/u2/u2.py | 2 +- paddlespeech/s2t/modules/ctc.py | 2 +- paddlespeech/server/README.md | 2 +- paddlespeech/server/README_cn.md | 2 +- .../server/bin/paddlespeech_client.py | 1 + .../server/engine/asr/online/ctc_search.py | 2 + .../tests/asr/online/websocket_client.py | 4 +- paddlespeech/t2s/exps/synthesize.py | 2 +- paddlespeech/vector/cluster/diarization.py | 2 +- .../ngram/zh/local/text_to_lexicon.py | 16 +- speechx/examples/text_lm/local/mmseg.py | 638 ++++++------ speechx/examples/wfst/README.md | 2 +- utils/DER.py | 2 +- utils/compute-wer.py | 964 +++++++++--------- utils/format_rsl.py | 77 +- utils/fst/prepare_dict.py | 7 +- 18 files changed, 910 insertions(+), 823 deletions(-) diff --git a/.flake8 b/.flake8 index 44685f23..6b50de7e 100644 --- a/.flake8 +++ b/.flake8 @@ -12,6 +12,8 @@ exclude = .git, # python cache __pycache__, + # third party + utils/compute-wer.py, third_party/, # Provide a comma-separate list of glob patterns to include for checks. filename = diff --git a/paddlespeech/cli/asr/infer.py b/paddlespeech/cli/asr/infer.py index 49dd7b35..97a1b321 100644 --- a/paddlespeech/cli/asr/infer.py +++ b/paddlespeech/cli/asr/infer.py @@ -40,6 +40,7 @@ from paddlespeech.s2t.utils.utility import UpdateConfig __all__ = ['ASRExecutor'] + @cli_register( name='paddlespeech.asr', description='Speech to text infer command.') class ASRExecutor(BaseExecutor): @@ -148,7 +149,7 @@ class ASRExecutor(BaseExecutor): os.path.dirname(os.path.abspath(self.cfg_path))) logger.info(self.cfg_path) logger.info(self.ckpt_path) - + #Init body. self.config = CfgNode(new_allowed=True) self.config.merge_from_file(self.cfg_path) @@ -278,7 +279,8 @@ class ASRExecutor(BaseExecutor): self._outputs["result"] = result_transcripts[0] elif "conformer" in model_type or "transformer" in model_type: - logger.info(f"we will use the transformer like model : {model_type}") + logger.info( + f"we will use the transformer like model : {model_type}") try: result_transcripts = self.model.decode( audio, diff --git a/paddlespeech/s2t/models/u2/u2.py b/paddlespeech/s2t/models/u2/u2.py index 9b66126e..530840d0 100644 --- a/paddlespeech/s2t/models/u2/u2.py +++ b/paddlespeech/s2t/models/u2/u2.py @@ -279,7 +279,7 @@ class U2BaseModel(ASRInterface, nn.Layer): # TODO(Hui Zhang): if end_flag.sum() == running_size: if end_flag.cast(paddle.int64).sum() == running_size: break - + # 2.1 Forward decoder step hyps_mask = subsequent_mask(i).unsqueeze(0).repeat( running_size, 1, 1).to(device) # (B*N, i, i) diff --git a/paddlespeech/s2t/modules/ctc.py b/paddlespeech/s2t/modules/ctc.py index 1bb15873..33ad472d 100644 --- a/paddlespeech/s2t/modules/ctc.py +++ b/paddlespeech/s2t/modules/ctc.py @@ -180,7 +180,7 @@ class CTCDecoder(CTCDecoderBase): # init once if self._ext_scorer is not None: return - + if language_model_path != '': logger.info("begin to initialize the external scorer " "for decoding") diff --git a/paddlespeech/server/README.md b/paddlespeech/server/README.md index 3ac68dae..8f140e4e 100644 --- a/paddlespeech/server/README.md +++ b/paddlespeech/server/README.md @@ -47,4 +47,4 @@ paddlespeech_server start --config_file conf/ws_conformer_application.yaml ``` paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input input_16k.wav -``` \ No newline at end of file +``` diff --git a/paddlespeech/server/README_cn.md b/paddlespeech/server/README_cn.md index 5f235313..91df9817 100644 --- a/paddlespeech/server/README_cn.md +++ b/paddlespeech/server/README_cn.md @@ -48,4 +48,4 @@ paddlespeech_server start --config_file conf/ws_conformer_application.yaml ``` paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input zh.wav -``` \ No newline at end of file +``` diff --git a/paddlespeech/server/bin/paddlespeech_client.py b/paddlespeech/server/bin/paddlespeech_client.py index 45469178..3ea14ab3 100644 --- a/paddlespeech/server/bin/paddlespeech_client.py +++ b/paddlespeech/server/bin/paddlespeech_client.py @@ -305,6 +305,7 @@ class ASRClientExecutor(BaseExecutor): return res['asr_results'] + @cli_client_register( name='paddlespeech_client.cls', description='visit cls service') class CLSClientExecutor(BaseExecutor): diff --git a/paddlespeech/server/engine/asr/online/ctc_search.py b/paddlespeech/server/engine/asr/online/ctc_search.py index 8aee0a50..be5fb15b 100644 --- a/paddlespeech/server/engine/asr/online/ctc_search.py +++ b/paddlespeech/server/engine/asr/online/ctc_search.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import defaultdict + import paddle + from paddlespeech.cli.log import logger from paddlespeech.s2t.utils.utility import log_add diff --git a/paddlespeech/server/tests/asr/online/websocket_client.py b/paddlespeech/server/tests/asr/online/websocket_client.py index 49cbd703..015698f5 100644 --- a/paddlespeech/server/tests/asr/online/websocket_client.py +++ b/paddlespeech/server/tests/asr/online/websocket_client.py @@ -36,7 +36,7 @@ class ASRAudioHandler: x_len = len(samples) chunk_size = 85 * 16 #80ms, sample_rate = 16kHz - if x_len % chunk_size!= 0: + if x_len % chunk_size != 0: padding_len_x = chunk_size - x_len % chunk_size else: padding_len_x = 0 @@ -92,7 +92,7 @@ class ASRAudioHandler: separators=(',', ': ')) await ws.send(audio_info) msg = await ws.recv() - + # decode the bytes to str msg = json.loads(msg) logging.info("final receive msg={}".format(msg)) diff --git a/paddlespeech/t2s/exps/synthesize.py b/paddlespeech/t2s/exps/synthesize.py index dd66e54e..0855a6a2 100644 --- a/paddlespeech/t2s/exps/synthesize.py +++ b/paddlespeech/t2s/exps/synthesize.py @@ -52,7 +52,7 @@ def evaluate(args): # acoustic model am_name = args.am[:args.am.rindex('_')] am_dataset = args.am[args.am.rindex('_') + 1:] - + am_inference = get_am_inference( am=args.am, am_config=am_config, diff --git a/paddlespeech/vector/cluster/diarization.py b/paddlespeech/vector/cluster/diarization.py index a8043c22..b47b3f24 100644 --- a/paddlespeech/vector/cluster/diarization.py +++ b/paddlespeech/vector/cluster/diarization.py @@ -20,11 +20,11 @@ A few sklearn functions are modified in this script as per requirement. import argparse import copy import warnings -from distutils.util import strtobool import numpy as np import scipy import sklearn +from distutils.util import strtobool from scipy import linalg from scipy import sparse from scipy.sparse.csgraph import connected_components diff --git a/speechx/examples/ngram/zh/local/text_to_lexicon.py b/speechx/examples/ngram/zh/local/text_to_lexicon.py index 0ccd07c7..ba5ab60a 100755 --- a/speechx/examples/ngram/zh/local/text_to_lexicon.py +++ b/speechx/examples/ngram/zh/local/text_to_lexicon.py @@ -2,6 +2,7 @@ import argparse from collections import Counter + def main(args): counter = Counter() with open(args.text, 'r') as fin, open(args.lexicon, 'w') as fout: @@ -12,7 +13,7 @@ def main(args): words = text.split() else: words = line.split() - + counter.update(words) for word in counter: @@ -20,21 +21,16 @@ def main(args): fout.write(f"{word}\t{val}\n") fout.flush() + if __name__ == '__main__': parser = argparse.ArgumentParser( description='text(line:utt1 中国 人) to lexicon(line:中国 中 国).') parser.add_argument( - '--has_key', - default=True, - help='text path, with utt or not') + '--has_key', default=True, help='text path, with utt or not') parser.add_argument( - '--text', - required=True, - help='text path. line: utt1 中国 人 or 中国 人') + '--text', required=True, help='text path. line: utt1 中国 人 or 中国 人') parser.add_argument( - '--lexicon', - required=True, - help='lexicon path. line:中国 中 国') + '--lexicon', required=True, help='lexicon path. line:中国 中 国') args = parser.parse_args() print(args) diff --git a/speechx/examples/text_lm/local/mmseg.py b/speechx/examples/text_lm/local/mmseg.py index 9b94ac31..74295cd3 100755 --- a/speechx/examples/text_lm/local/mmseg.py +++ b/speechx/examples/text_lm/local/mmseg.py @@ -1,305 +1,315 @@ #!/usr/bin/env python3 - # modify from https://sites.google.com/site/homepageoffuyanwei/Home/remarksandexcellentdiscussion/page-2 -class Word: - def __init__(self,text = '',freq = 0): - self.text = text - self.freq = freq - self.length = len(text) - + +class Word: + def __init__(self, text='', freq=0): + self.text = text + self.freq = freq + self.length = len(text) + + class Chunk: - def __init__(self,w1,w2 = None,w3 = None): - self.words = [] - self.words.append(w1) - if w2: - self.words.append(w2) - if w3: - self.words.append(w3) - + def __init__(self, w1, w2=None, w3=None): + self.words = [] + self.words.append(w1) + if w2: + self.words.append(w2) + if w3: + self.words.append(w3) + #计算chunk的总长度 - def totalWordLength(self): - length = 0 - for word in self.words: - length += len(word.text) - return length - + def totalWordLength(self): + length = 0 + for word in self.words: + length += len(word.text) + return length + #计算平均长度 - def averageWordLength(self): - return float(self.totalWordLength()) / float(len(self.words)) - + def averageWordLength(self): + return float(self.totalWordLength()) / float(len(self.words)) + #计算标准差 - def standardDeviation(self): - average = self.averageWordLength() - sum = 0.0 - for word in self.words: - tmp = (len(word.text) - average) - sum += float(tmp) * float(tmp) - return sum - + def standardDeviation(self): + average = self.averageWordLength() + sum = 0.0 + for word in self.words: + tmp = (len(word.text) - average) + sum += float(tmp) * float(tmp) + return sum + #自由语素度 def wordFrequency(self): - sum = 0 - for word in self.words: - sum += word.freq - return sum - -class ComplexCompare: - + sum = 0 + for word in self.words: + sum += word.freq + return sum + + +class ComplexCompare: def takeHightest(self, chunks, comparator): - i = 1 - for j in range(1, len(chunks)): - rlt = comparator(chunks[j], chunks[0]) - if rlt > 0: - i = 0 - if rlt >= 0: - chunks[i], chunks[j] = chunks[j], chunks[i] - i += 1 + i = 1 + for j in range(1, len(chunks)): + rlt = comparator(chunks[j], chunks[0]) + if rlt > 0: + i = 0 + if rlt >= 0: + chunks[i], chunks[j] = chunks[j], chunks[i] + i += 1 return chunks[0:i] - + #以下四个函数是mmseg算法的四种过滤原则,核心算法 def mmFilter(self, chunks): - def comparator(a,b): - return a.totalWordLength() - b.totalWordLength() - return self.takeHightest(chunks, comparator) - - def lawlFilter(self,chunks): - def comparator(a,b): - return a.averageWordLength() - b.averageWordLength() - return self.takeHightest(chunks,comparator) - - def svmlFilter(self,chunks): - def comparator(a,b): - return b.standardDeviation() - a.standardDeviation() - return self.takeHightest(chunks, comparator) - - def logFreqFilter(self,chunks): - def comparator(a,b): - return a.wordFrequency() - b.wordFrequency() - return self.takeHightest(chunks, comparator) - - + def comparator(a, b): + return a.totalWordLength() - b.totalWordLength() + + return self.takeHightest(chunks, comparator) + + def lawlFilter(self, chunks): + def comparator(a, b): + return a.averageWordLength() - b.averageWordLength() + + return self.takeHightest(chunks, comparator) + + def svmlFilter(self, chunks): + def comparator(a, b): + return b.standardDeviation() - a.standardDeviation() + + return self.takeHightest(chunks, comparator) + + def logFreqFilter(self, chunks): + def comparator(a, b): + return a.wordFrequency() - b.wordFrequency() + + return self.takeHightest(chunks, comparator) + + #加载词组字典和字符字典 dictWord = {} maxWordLength = 0 - -def loadDictChars(filepath): - global maxWordLength + + +def loadDictChars(filepath): + global maxWordLength fsock = open(filepath) for line in fsock: freq, word = line.split() word = word.strip() - dictWord[word] = (len(word), int(freq)) - maxWordLength = len(word) if maxWordLength < len(word) else maxWordLength - fsock.close() - -def loadDictWords(filepath): - global maxWordLength - fsock = open(filepath) - for line in fsock.readlines(): + dictWord[word] = (len(word), int(freq)) + maxWordLength = len(word) if maxWordLength < len( + word) else maxWordLength + fsock.close() + + +def loadDictWords(filepath): + global maxWordLength + fsock = open(filepath) + for line in fsock.readlines(): word = line.strip() - dictWord[word] = (len(word), 0) - maxWordLength = len(word) if maxWordLength < len(word) else maxWordLength - fsock.close() - + dictWord[word] = (len(word), 0) + maxWordLength = len(word) if maxWordLength < len( + word) else maxWordLength + fsock.close() + + #判断该词word是否在字典dictWord中 -def getDictWord(word): - result = dictWord.get(word) - if result: - return Word(word, result[1]) - return None - +def getDictWord(word): + result = dictWord.get(word) + if result: + return Word(word, result[1]) + return None + + #开始加载字典 -def run(): - from os.path import join, dirname - loadDictChars(join(dirname(__file__), 'data', 'chars.dic')) - loadDictWords(join(dirname(__file__), 'data', 'words.dic')) - -class Analysis: - - def __init__(self, text): +def run(): + from os.path import join, dirname + loadDictChars(join(dirname(__file__), 'data', 'chars.dic')) + loadDictWords(join(dirname(__file__), 'data', 'words.dic')) + + +class Analysis: + def __init__(self, text): self.text = text - self.cacheSize = 3 - self.pos = 0 - self.textLength = len(self.text) - self.cache = [] - self.cacheIndex = 0 - self.complexCompare = ComplexCompare() - + self.cacheSize = 3 + self.pos = 0 + self.textLength = len(self.text) + self.cache = [] + self.cacheIndex = 0 + self.complexCompare = ComplexCompare() + #简单小技巧,用到个缓存,不知道具体有没有用处 - for i in range(self.cacheSize): + for i in range(self.cacheSize): self.cache.append([-1, Word()]) - + #控制字典只加载一次 if not dictWord: run() + def __iter__(self): + while True: + token = self.getNextToken() + if token is None: + raise StopIteration + yield token + + def getNextChar(self): + return self.text[self.pos] - def __iter__(self): - while True: - token = self.getNextToken() - if token == None: - raise StopIteration - yield token - - def getNextChar(self): - return self.text[self.pos] - #判断该字符是否是中文字符(不包括中文标点) - def isChineseChar(self,charater): - return 0x4e00 <= ord(charater) < 0x9fa6 - + def isChineseChar(self, charater): + return 0x4e00 <= ord(charater) < 0x9fa6 + #判断是否是ASCII码 - def isASCIIChar(self, ch): - import string - if ch in string.whitespace: - return False - if ch in string.punctuation: - return False + def isASCIIChar(self, ch): + import string + if ch in string.whitespace: + return False + if ch in string.punctuation: + return False return ch in string.printable - + #得到下一个切割结果 - def getNextToken(self): - while self.pos < self.textLength: - if self.isChineseChar(self.getNextChar()): - token = self.getChineseWords() - else : - token = self.getASCIIWords()+'/' - if len(token) > 0: + def getNextToken(self): + while self.pos < self.textLength: + if self.isChineseChar(self.getNextChar()): + token = self.getChineseWords() + else: + token = self.getASCIIWords() + '/' + if len(token) > 0: return token return None - + #切割出非中文词 def getASCIIWords(self): # Skip pre-word whitespaces and punctuations #跳过中英文标点和空格 - while self.pos < self.textLength: - ch = self.getNextChar() - if self.isASCIIChar(ch) or self.isChineseChar(ch): - break - self.pos += 1 + while self.pos < self.textLength: + ch = self.getNextChar() + if self.isASCIIChar(ch) or self.isChineseChar(ch): + break + self.pos += 1 #得到英文单词的起始位置 - start = self.pos - + start = self.pos + #找出英文单词的结束位置 - while self.pos < self.textLength: - ch = self.getNextChar() - if not self.isASCIIChar(ch): - break - self.pos += 1 - end = self.pos - + while self.pos < self.textLength: + ch = self.getNextChar() + if not self.isASCIIChar(ch): + break + self.pos += 1 + end = self.pos + #Skip chinese word whitespaces and punctuations #跳过中英文标点和空格 - while self.pos < self.textLength: - ch = self.getNextChar() - if self.isASCIIChar(ch) or self.isChineseChar(ch): - break - self.pos += 1 - + while self.pos < self.textLength: + ch = self.getNextChar() + if self.isASCIIChar(ch) or self.isChineseChar(ch): + break + self.pos += 1 + #返回英文单词 - return self.text[start:end] - + return self.text[start:end] + #切割出中文词,并且做处理,用上述4种方法 - def getChineseWords(self): - chunks = self.createChunks() - if len(chunks) > 1: - chunks = self.complexCompare.mmFilter(chunks) - if len(chunks) > 1: - chunks = self.complexCompare.lawlFilter(chunks) - if len(chunks) > 1: - chunks = self.complexCompare.svmlFilter(chunks) - if len(chunks) > 1: - chunks = self.complexCompare.logFreqFilter(chunks) - if len(chunks) == 0 : - return '' + def getChineseWords(self): + chunks = self.createChunks() + if len(chunks) > 1: + chunks = self.complexCompare.mmFilter(chunks) + if len(chunks) > 1: + chunks = self.complexCompare.lawlFilter(chunks) + if len(chunks) > 1: + chunks = self.complexCompare.svmlFilter(chunks) + if len(chunks) > 1: + chunks = self.complexCompare.logFreqFilter(chunks) + if len(chunks) == 0: + return '' #最后只有一种切割方法 - word = chunks[0].words - token = "" - length = 0 - for x in word: - if x.length != -1: - token += x.text + "/" - length += len(x.text) - self.pos += length - return token - + word = chunks[0].words + token = "" + length = 0 + for x in word: + if x.length != -1: + token += x.text + "/" + length += len(x.text) + self.pos += length + return token + #三重循环来枚举切割方法,这里也可以运用递归来实现 - def createChunks(self): - chunks = [] - originalPos = self.pos - words1 = self.getMatchChineseWords() - - for word1 in words1: - self.pos += len(word1.text) - if self.pos < self.textLength: - words2 = self.getMatchChineseWords() - for word2 in words2: - self.pos += len(word2.text) - if self.pos < self.textLength: - words3 = self.getMatchChineseWords() - for word3 in words3: + def createChunks(self): + chunks = [] + originalPos = self.pos + words1 = self.getMatchChineseWords() + + for word1 in words1: + self.pos += len(word1.text) + if self.pos < self.textLength: + words2 = self.getMatchChineseWords() + for word2 in words2: + self.pos += len(word2.text) + if self.pos < self.textLength: + words3 = self.getMatchChineseWords() + for word3 in words3: # print(word3.length, word3.text) - if word3.length == -1: - chunk = Chunk(word1,word2) + if word3.length == -1: + chunk = Chunk(word1, word2) # print("Ture") - else : - chunk = Chunk(word1,word2,word3) - chunks.append(chunk) - elif self.pos == self.textLength: - chunks.append(Chunk(word1,word2)) - self.pos -= len(word2.text) - elif self.pos == self.textLength: - chunks.append(Chunk(word1)) - self.pos -= len(word1.text) - - self.pos = originalPos - return chunks - + else: + chunk = Chunk(word1, word2, word3) + chunks.append(chunk) + elif self.pos == self.textLength: + chunks.append(Chunk(word1, word2)) + self.pos -= len(word2.text) + elif self.pos == self.textLength: + chunks.append(Chunk(word1)) + self.pos -= len(word1.text) + + self.pos = originalPos + return chunks + #运用正向最大匹配算法结合字典来切割中文文本 - def getMatchChineseWords(self): + def getMatchChineseWords(self): #use cache,check it - for i in range(self.cacheSize): - if self.cache[i][0] == self.pos: - return self.cache[i][1] - - originalPos = self.pos - words = [] - index = 0 - while self.pos < self.textLength: - if index >= maxWordLength : - break - if not self.isChineseChar(self.getNextChar()): - break - self.pos += 1 - index += 1 - - text = self.text[originalPos:self.pos] - word = getDictWord(text) - if word: - words.append(word) - - self.pos = originalPos + for i in range(self.cacheSize): + if self.cache[i][0] == self.pos: + return self.cache[i][1] + + originalPos = self.pos + words = [] + index = 0 + while self.pos < self.textLength: + if index >= maxWordLength: + break + if not self.isChineseChar(self.getNextChar()): + break + self.pos += 1 + index += 1 + + text = self.text[originalPos:self.pos] + word = getDictWord(text) + if word: + words.append(word) + + self.pos = originalPos #没有词则放置个‘X’,将文本长度标记为-1 - if not words: - word = Word() - word.length = -1 - word.text = 'X' - words.append(word) - - self.cache[self.cacheIndex] = (self.pos,words) - self.cacheIndex += 1 - if self.cacheIndex >= self.cacheSize: - self.cacheIndex = 0 - return words - - -if __name__=="__main__": - - def cuttest(text): + if not words: + word = Word() + word.length = -1 + word.text = 'X' + words.append(word) + + self.cache[self.cacheIndex] = (self.pos, words) + self.cacheIndex += 1 + if self.cacheIndex >= self.cacheSize: + self.cacheIndex = 0 + return words + + +if __name__ == "__main__": + + def cuttest(text): #cut = Analysis(text) - tmp="" + tmp = "" try: for word in iter(Analysis(text)): tmp += word @@ -310,71 +320,73 @@ if __name__=="__main__": print("================================") cuttest(u"研究生命来源") - cuttest(u"南京市长江大桥欢迎您") - cuttest(u"请把手抬高一点儿") - cuttest(u"长春市长春节致词。") - cuttest(u"长春市长春药店。") - cuttest(u"我的和服务必在明天做好。") - cuttest(u"我发现有很多人喜欢他。") - cuttest(u"我喜欢看电视剧大长今。") - cuttest(u"半夜给拎起来陪看欧洲杯糊着两眼半晌没搞明白谁和谁踢。") - cuttest(u"李智伟高高兴兴以及王晓薇出去玩,后来智伟和晓薇又单独去玩了。") - cuttest(u"一次性交出去很多钱。 ") - cuttest(u"这是一个伸手不见五指的黑夜。我叫孙悟空,我爱北京,我爱Python和C++。") - cuttest(u"我不喜欢日本和服。") - cuttest(u"雷猴回归人间。") - cuttest(u"工信处女干事每月经过下属科室都要亲口交代24口交换机等技术性器件的安装工作") - cuttest(u"我需要廉租房") - cuttest(u"永和服装饰品有限公司") - cuttest(u"我爱北京天安门") - cuttest(u"abc") - cuttest(u"隐马尔可夫") - cuttest(u"雷猴是个好网站") - cuttest(u"“Microsoft”一词由“MICROcomputer(微型计算机)”和“SOFTware(软件)”两部分组成") - cuttest(u"草泥马和欺实马是今年的流行词汇") - cuttest(u"伊藤洋华堂总府店") - cuttest(u"中国科学院计算技术研究所") - cuttest(u"罗密欧与朱丽叶") - cuttest(u"我购买了道具和服装") - cuttest(u"PS: 我觉得开源有一个好处,就是能够敦促自己不断改进,避免敞帚自珍") - cuttest(u"湖北省石首市") - cuttest(u"总经理完成了这件事情") - cuttest(u"电脑修好了") - cuttest(u"做好了这件事情就一了百了了") - cuttest(u"人们审美的观点是不同的") - cuttest(u"我们买了一个美的空调") - cuttest(u"线程初始化时我们要注意") - cuttest(u"一个分子是由好多原子组织成的") - cuttest(u"祝你马到功成") - cuttest(u"他掉进了无底洞里") - cuttest(u"中国的首都是北京") - cuttest(u"孙君意") - cuttest(u"外交部发言人马朝旭") - cuttest(u"领导人会议和第四届东亚峰会") - cuttest(u"在过去的这五年") - cuttest(u"还需要很长的路要走") - cuttest(u"60周年首都阅兵") - cuttest(u"你好人们审美的观点是不同的") - cuttest(u"买水果然后来世博园") - cuttest(u"买水果然后去世博园") - cuttest(u"但是后来我才知道你是对的") - cuttest(u"存在即合理") - cuttest(u"的的的的的在的的的的就以和和和") - cuttest(u"I love你,不以为耻,反以为rong") - cuttest(u" ") - cuttest(u"") - cuttest(u"hello你好人们审美的观点是不同的") - cuttest(u"很好但主要是基于网页形式") - cuttest(u"hello你好人们审美的观点是不同的") - cuttest(u"为什么我不能拥有想要的生活") - cuttest(u"后来我才") - cuttest(u"此次来中国是为了") - cuttest(u"使用了它就可以解决一些问题") - cuttest(u",使用了它就可以解决一些问题") - cuttest(u"其实使用了它就可以解决一些问题") - cuttest(u"好人使用了它就可以解决一些问题") - cuttest(u"是因为和国家") - cuttest(u"老年搜索还支持") - cuttest(u"干脆就把那部蒙人的闲法给废了拉倒!RT @laoshipukong : 27日,全国人大常委会第三次审议侵权责任法草案,删除了有关医疗损害责任“举证倒置”的规定。在医患纠纷中本已处于弱势地位的消费者由此将陷入万劫不复的境地。 ") + cuttest(u"南京市长江大桥欢迎您") + cuttest(u"请把手抬高一点儿") + cuttest(u"长春市长春节致词。") + cuttest(u"长春市长春药店。") + cuttest(u"我的和服务必在明天做好。") + cuttest(u"我发现有很多人喜欢他。") + cuttest(u"我喜欢看电视剧大长今。") + cuttest(u"半夜给拎起来陪看欧洲杯糊着两眼半晌没搞明白谁和谁踢。") + cuttest(u"李智伟高高兴兴以及王晓薇出去玩,后来智伟和晓薇又单独去玩了。") + cuttest(u"一次性交出去很多钱。 ") + cuttest(u"这是一个伸手不见五指的黑夜。我叫孙悟空,我爱北京,我爱Python和C++。") + cuttest(u"我不喜欢日本和服。") + cuttest(u"雷猴回归人间。") + cuttest(u"工信处女干事每月经过下属科室都要亲口交代24口交换机等技术性器件的安装工作") + cuttest(u"我需要廉租房") + cuttest(u"永和服装饰品有限公司") + cuttest(u"我爱北京天安门") + cuttest(u"abc") + cuttest(u"隐马尔可夫") + cuttest(u"雷猴是个好网站") + cuttest(u"“Microsoft”一词由“MICROcomputer(微型计算机)”和“SOFTware(软件)”两部分组成") + cuttest(u"草泥马和欺实马是今年的流行词汇") + cuttest(u"伊藤洋华堂总府店") + cuttest(u"中国科学院计算技术研究所") + cuttest(u"罗密欧与朱丽叶") + cuttest(u"我购买了道具和服装") + cuttest(u"PS: 我觉得开源有一个好处,就是能够敦促自己不断改进,避免敞帚自珍") + cuttest(u"湖北省石首市") + cuttest(u"总经理完成了这件事情") + cuttest(u"电脑修好了") + cuttest(u"做好了这件事情就一了百了了") + cuttest(u"人们审美的观点是不同的") + cuttest(u"我们买了一个美的空调") + cuttest(u"线程初始化时我们要注意") + cuttest(u"一个分子是由好多原子组织成的") + cuttest(u"祝你马到功成") + cuttest(u"他掉进了无底洞里") + cuttest(u"中国的首都是北京") + cuttest(u"孙君意") + cuttest(u"外交部发言人马朝旭") + cuttest(u"领导人会议和第四届东亚峰会") + cuttest(u"在过去的这五年") + cuttest(u"还需要很长的路要走") + cuttest(u"60周年首都阅兵") + cuttest(u"你好人们审美的观点是不同的") + cuttest(u"买水果然后来世博园") + cuttest(u"买水果然后去世博园") + cuttest(u"但是后来我才知道你是对的") + cuttest(u"存在即合理") + cuttest(u"的的的的的在的的的的就以和和和") + cuttest(u"I love你,不以为耻,反以为rong") + cuttest(u" ") + cuttest(u"") + cuttest(u"hello你好人们审美的观点是不同的") + cuttest(u"很好但主要是基于网页形式") + cuttest(u"hello你好人们审美的观点是不同的") + cuttest(u"为什么我不能拥有想要的生活") + cuttest(u"后来我才") + cuttest(u"此次来中国是为了") + cuttest(u"使用了它就可以解决一些问题") + cuttest(u",使用了它就可以解决一些问题") + cuttest(u"其实使用了它就可以解决一些问题") + cuttest(u"好人使用了它就可以解决一些问题") + cuttest(u"是因为和国家") + cuttest(u"老年搜索还支持") + cuttest( + u"干脆就把那部蒙人的闲法给废了拉倒!RT @laoshipukong : 27日,全国人大常委会第三次审议侵权责任法草案,删除了有关医疗损害责任“举证倒置”的规定。在医患纠纷中本已处于弱势地位的消费者由此将陷入万劫不复的境地。 " + ) cuttest("2022年12月30日是星期几?") - cuttest("二零二二年十二月三十日是星期几?") \ No newline at end of file + cuttest("二零二二年十二月三十日是星期几?") diff --git a/speechx/examples/wfst/README.md b/speechx/examples/wfst/README.md index 4f4674a4..d0bdac0f 100644 --- a/speechx/examples/wfst/README.md +++ b/speechx/examples/wfst/README.md @@ -183,4 +183,4 @@ data/ ├── lexiconp_disambig.txt ├── lexiconp.txt └── units.list -``` \ No newline at end of file +``` diff --git a/utils/DER.py b/utils/DER.py index d6ab695d..59bcbec4 100755 --- a/utils/DER.py +++ b/utils/DER.py @@ -26,9 +26,9 @@ import argparse import os import re import subprocess -from distutils.util import strtobool import numpy as np +from distutils.util import strtobool FILE_IDS = re.compile(r"(?<=Speaker Diarization for).+(?=\*\*\*)") SCORED_SPEAKER_TIME = re.compile(r"(?<=SCORED SPEAKER TIME =)[\d.]+") diff --git a/utils/compute-wer.py b/utils/compute-wer.py index b3dbf225..978a80c9 100755 --- a/utils/compute-wer.py +++ b/utils/compute-wer.py @@ -1,61 +1,66 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- # CopyRight WeNet Apache-2.0 License - -import re, sys, unicodedata import codecs +import re +import sys +import unicodedata remove_tag = True -spacelist= [' ', '\t', '\r', '\n'] -puncts = ['!', ',', '?', - '、', '。', '!', ',', ';', '?', - ':', '「', '」', '︰', '『', '』', '《', '》'] +spacelist = [' ', '\t', '\r', '\n'] +puncts = [ + '!', ',', '?', '、', '。', '!', ',', ';', '?', ':', '「', '」', '︰', '『', '』', + '《', '》' +] + + +def characterize(string): + res = [] + i = 0 + while i < len(string): + char = string[i] + if char in puncts: + i += 1 + continue + cat1 = unicodedata.category(char) + #https://unicodebook.readthedocs.io/unicode.html#unicode-categories + if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned + i += 1 + continue + if cat1 == 'Lo': # letter-other + res.append(char) + i += 1 + else: + # some input looks like: , we want to separate it to two words. + sep = ' ' + if char == '<': sep = '>' + j = i + 1 + while j < len(string): + c = string[j] + if ord(c) >= 128 or (c in spacelist) or (c == sep): + break + j += 1 + if j < len(string) and string[j] == '>': + j += 1 + res.append(string[i:j]) + i = j + return res -def characterize(string) : - res = [] - i = 0 - while i < len(string): - char = string[i] - if char in puncts: - i += 1 - continue - cat1 = unicodedata.category(char) - #https://unicodebook.readthedocs.io/unicode.html#unicode-categories - if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned - i += 1 - continue - if cat1 == 'Lo': # letter-other - res.append(char) - i += 1 - else: - # some input looks like: , we want to separate it to two words. - sep = ' ' - if char == '<': sep = '>' - j = i+1 - while j < len(string): - c = string[j] - if ord(c) >= 128 or (c in spacelist) or (c==sep): - break - j += 1 - if j < len(string) and string[j] == '>': - j += 1 - res.append(string[i:j]) - i = j - return res def stripoff_tags(x): - if not x: return '' - chars = [] - i = 0; T=len(x) - while i < T: - if x[i] == '<': - while i < T and x[i] != '>': - i += 1 - i += 1 - else: - chars.append(x[i]) - i += 1 - return ''.join(chars) + if not x: return '' + chars = [] + i = 0 + T = len(x) + while i < T: + if x[i] == '<': + while i < T and x[i] != '>': + i += 1 + i += 1 + else: + chars.append(x[i]) + i += 1 + return ''.join(chars) def normalize(sentence, ignore_words, cs, split=None): @@ -65,436 +70,485 @@ def normalize(sentence, ignore_words, cs, split=None): for token in sentence: x = token if not cs: - x = x.upper() + x = x.upper() if x in ignore_words: - continue + continue if remove_tag: - x = stripoff_tags(x) + x = stripoff_tags(x) if not x: - continue + continue if split and x in split: - new_sentence += split[x] + new_sentence += split[x] else: - new_sentence.append(x) + new_sentence.append(x) return new_sentence -class Calculator : - def __init__(self) : - self.data = {} - self.space = [] - self.cost = {} - self.cost['cor'] = 0 - self.cost['sub'] = 1 - self.cost['del'] = 1 - self.cost['ins'] = 1 - def calculate(self, lab, rec) : - # Initialization - lab.insert(0, '') - rec.insert(0, '') - while len(self.space) < len(lab) : - self.space.append([]) - for row in self.space : - for element in row : - element['dist'] = 0 - element['error'] = 'non' - while len(row) < len(rec) : - row.append({'dist' : 0, 'error' : 'non'}) - for i in range(len(lab)) : - self.space[i][0]['dist'] = i - self.space[i][0]['error'] = 'del' - for j in range(len(rec)) : - self.space[0][j]['dist'] = j - self.space[0][j]['error'] = 'ins' - self.space[0][0]['error'] = 'non' - for token in lab : - if token not in self.data and len(token) > 0 : - self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0} - for token in rec : - if token not in self.data and len(token) > 0 : - self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0} - # Computing edit distance - for i, lab_token in enumerate(lab) : - for j, rec_token in enumerate(rec) : - if i == 0 or j == 0 : - continue - min_dist = sys.maxsize - min_error = 'none' - dist = self.space[i-1][j]['dist'] + self.cost['del'] - error = 'del' - if dist < min_dist : - min_dist = dist - min_error = error - dist = self.space[i][j-1]['dist'] + self.cost['ins'] - error = 'ins' - if dist < min_dist : - min_dist = dist - min_error = error - if lab_token == rec_token : - dist = self.space[i-1][j-1]['dist'] + self.cost['cor'] - error = 'cor' - else : - dist = self.space[i-1][j-1]['dist'] + self.cost['sub'] - error = 'sub' - if dist < min_dist : - min_dist = dist - min_error = error - self.space[i][j]['dist'] = min_dist - self.space[i][j]['error'] = min_error - # Tracing back - result = {'lab':[], 'rec':[], 'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0} - i = len(lab) - 1 - j = len(rec) - 1 - while True : - if self.space[i][j]['error'] == 'cor' : # correct - if len(lab[i]) > 0 : - self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 - self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1 - result['all'] = result['all'] + 1 - result['cor'] = result['cor'] + 1 - result['lab'].insert(0, lab[i]) - result['rec'].insert(0, rec[j]) - i = i - 1 - j = j - 1 - elif self.space[i][j]['error'] == 'sub' : # substitution - if len(lab[i]) > 0 : - self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 - self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1 - result['all'] = result['all'] + 1 - result['sub'] = result['sub'] + 1 - result['lab'].insert(0, lab[i]) - result['rec'].insert(0, rec[j]) - i = i - 1 - j = j - 1 - elif self.space[i][j]['error'] == 'del' : # deletion - if len(lab[i]) > 0 : - self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 - self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1 - result['all'] = result['all'] + 1 - result['del'] = result['del'] + 1 - result['lab'].insert(0, lab[i]) - result['rec'].insert(0, "") - i = i - 1 - elif self.space[i][j]['error'] == 'ins' : # insertion - if len(rec[j]) > 0 : - self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1 - result['ins'] = result['ins'] + 1 - result['lab'].insert(0, "") - result['rec'].insert(0, rec[j]) - j = j - 1 - elif self.space[i][j]['error'] == 'non' : # starting point - break - else : # shouldn't reach here - print('this should not happen , i = {i} , j = {j} , error = {error}'.format(i = i, j = j, error = self.space[i][j]['error'])) - return result - def overall(self) : - result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0} - for token in self.data : - result['all'] = result['all'] + self.data[token]['all'] - result['cor'] = result['cor'] + self.data[token]['cor'] - result['sub'] = result['sub'] + self.data[token]['sub'] - result['ins'] = result['ins'] + self.data[token]['ins'] - result['del'] = result['del'] + self.data[token]['del'] - return result - def cluster(self, data) : - result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0} - for token in data : - if token in self.data : - result['all'] = result['all'] + self.data[token]['all'] - result['cor'] = result['cor'] + self.data[token]['cor'] - result['sub'] = result['sub'] + self.data[token]['sub'] - result['ins'] = result['ins'] + self.data[token]['ins'] - result['del'] = result['del'] + self.data[token]['del'] - return result - def keys(self) : - return list(self.data.keys()) + +class Calculator: + def __init__(self): + self.data = {} + self.space = [] + self.cost = {} + self.cost['cor'] = 0 + self.cost['sub'] = 1 + self.cost['del'] = 1 + self.cost['ins'] = 1 + + def calculate(self, lab, rec): + # Initialization + lab.insert(0, '') + rec.insert(0, '') + while len(self.space) < len(lab): + self.space.append([]) + for row in self.space: + for element in row: + element['dist'] = 0 + element['error'] = 'non' + while len(row) < len(rec): + row.append({'dist': 0, 'error': 'non'}) + for i in range(len(lab)): + self.space[i][0]['dist'] = i + self.space[i][0]['error'] = 'del' + for j in range(len(rec)): + self.space[0][j]['dist'] = j + self.space[0][j]['error'] = 'ins' + self.space[0][0]['error'] = 'non' + for token in lab: + if token not in self.data and len(token) > 0: + self.data[token] = { + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } + for token in rec: + if token not in self.data and len(token) > 0: + self.data[token] = { + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } + # Computing edit distance + for i, lab_token in enumerate(lab): + for j, rec_token in enumerate(rec): + if i == 0 or j == 0: + continue + min_dist = sys.maxsize + min_error = 'none' + dist = self.space[i - 1][j]['dist'] + self.cost['del'] + error = 'del' + if dist < min_dist: + min_dist = dist + min_error = error + dist = self.space[i][j - 1]['dist'] + self.cost['ins'] + error = 'ins' + if dist < min_dist: + min_dist = dist + min_error = error + if lab_token == rec_token: + dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor'] + error = 'cor' + else: + dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub'] + error = 'sub' + if dist < min_dist: + min_dist = dist + min_error = error + self.space[i][j]['dist'] = min_dist + self.space[i][j]['error'] = min_error + # Tracing back + result = { + 'lab': [], + 'rec': [], + 'all': 0, + 'cor': 0, + 'sub': 0, + 'ins': 0, + 'del': 0 + } + i = len(lab) - 1 + j = len(rec) - 1 + while True: + if self.space[i][j]['error'] == 'cor': # correct + if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1 + result['all'] = result['all'] + 1 + result['cor'] = result['cor'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'sub': # substitution + if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1 + result['all'] = result['all'] + 1 + result['sub'] = result['sub'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, rec[j]) + i = i - 1 + j = j - 1 + elif self.space[i][j]['error'] == 'del': # deletion + if len(lab[i]) > 0: + self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1 + self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1 + result['all'] = result['all'] + 1 + result['del'] = result['del'] + 1 + result['lab'].insert(0, lab[i]) + result['rec'].insert(0, "") + i = i - 1 + elif self.space[i][j]['error'] == 'ins': # insertion + if len(rec[j]) > 0: + self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1 + result['ins'] = result['ins'] + 1 + result['lab'].insert(0, "") + result['rec'].insert(0, rec[j]) + j = j - 1 + elif self.space[i][j]['error'] == 'non': # starting point + break + else: # shouldn't reach here + print( + 'this should not happen , i = {i} , j = {j} , error = {error}'. + format(i=i, j=j, error=self.space[i][j]['error'])) + return result + + def overall(self): + result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} + for token in self.data: + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + + def cluster(self, data): + result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0} + for token in data: + if token in self.data: + result['all'] = result['all'] + self.data[token]['all'] + result['cor'] = result['cor'] + self.data[token]['cor'] + result['sub'] = result['sub'] + self.data[token]['sub'] + result['ins'] = result['ins'] + self.data[token]['ins'] + result['del'] = result['del'] + self.data[token]['del'] + return result + + def keys(self): + return list(self.data.keys()) + def width(string): - return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string) + return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string) -def default_cluster(word) : - unicode_names = [ unicodedata.name(char) for char in word ] - for i in reversed(range(len(unicode_names))) : - if unicode_names[i].startswith('DIGIT') : # 1 - unicode_names[i] = 'Number' # 'DIGIT' - elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or - unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')) : - # 明 / 郎 - unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH' - elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or - unicode_names[i].startswith('LATIN SMALL LETTER')) : - # A / a - unicode_names[i] = 'English' # 'LATIN LETTER' - elif unicode_names[i].startswith('HIRAGANA LETTER') : # は こ め - unicode_names[i] = 'Japanese' # 'GANA LETTER' - elif (unicode_names[i].startswith('AMPERSAND') or - unicode_names[i].startswith('APOSTROPHE') or - unicode_names[i].startswith('COMMERCIAL AT') or - unicode_names[i].startswith('DEGREE CELSIUS') or - unicode_names[i].startswith('EQUALS SIGN') or - unicode_names[i].startswith('FULL STOP') or - unicode_names[i].startswith('HYPHEN-MINUS') or - unicode_names[i].startswith('LOW LINE') or - unicode_names[i].startswith('NUMBER SIGN') or - unicode_names[i].startswith('PLUS SIGN') or - unicode_names[i].startswith('SEMICOLON')) : - # & / ' / @ / ℃ / = / . / - / _ / # / + / ; - del unicode_names[i] - else : - return 'Other' - if len(unicode_names) == 0 : - return 'Other' - if len(unicode_names) == 1 : - return unicode_names[0] - for i in range(len(unicode_names)-1) : - if unicode_names[i] != unicode_names[i+1] : - return 'Other' - return unicode_names[0] -def usage() : - print("compute-wer.py : compute word error rate (WER) and align recognition results and references.") - print(" usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer") +def default_cluster(word): + unicode_names = [unicodedata.name(char) for char in word] + for i in reversed(range(len(unicode_names))): + if unicode_names[i].startswith('DIGIT'): # 1 + unicode_names[i] = 'Number' # 'DIGIT' + elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or + unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')): + # 明 / 郎 + unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH' + elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or + unicode_names[i].startswith('LATIN SMALL LETTER')): + # A / a + unicode_names[i] = 'English' # 'LATIN LETTER' + elif unicode_names[i].startswith('HIRAGANA LETTER'): # は こ め + unicode_names[i] = 'Japanese' # 'GANA LETTER' + elif (unicode_names[i].startswith('AMPERSAND') or + unicode_names[i].startswith('APOSTROPHE') or + unicode_names[i].startswith('COMMERCIAL AT') or + unicode_names[i].startswith('DEGREE CELSIUS') or + unicode_names[i].startswith('EQUALS SIGN') or + unicode_names[i].startswith('FULL STOP') or + unicode_names[i].startswith('HYPHEN-MINUS') or + unicode_names[i].startswith('LOW LINE') or + unicode_names[i].startswith('NUMBER SIGN') or + unicode_names[i].startswith('PLUS SIGN') or + unicode_names[i].startswith('SEMICOLON')): + # & / ' / @ / ℃ / = / . / - / _ / # / + / ; + del unicode_names[i] + else: + return 'Other' + if len(unicode_names) == 0: + return 'Other' + if len(unicode_names) == 1: + return unicode_names[0] + for i in range(len(unicode_names) - 1): + if unicode_names[i] != unicode_names[i + 1]: + return 'Other' + return unicode_names[0] + + +def usage(): + print( + "compute-wer.py : compute word error rate (WER) and align recognition results and references." + ) + print( + " usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer" + ) + if __name__ == '__main__': - if len(sys.argv) == 1 : - usage() - sys.exit(0) - calculator = Calculator() - cluster_file = '' - ignore_words = set() - tochar = False - verbose= 1 - padding_symbol= ' ' - case_sensitive = False - max_words_per_line = sys.maxsize - split = None - while len(sys.argv) > 3: - a = '--maxw=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):] - del sys.argv[1] - max_words_per_line = int(b) - continue - a = '--rt=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):].lower() - del sys.argv[1] - remove_tag = (b == 'true') or (b != '0') - continue - a = '--cs=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):].lower() - del sys.argv[1] - case_sensitive = (b == 'true') or (b != '0') - continue - a = '--cluster=' - if sys.argv[1].startswith(a): - cluster_file = sys.argv[1][len(a):] - del sys.argv[1] - continue - a = '--splitfile=' - if sys.argv[1].startswith(a): - split_file = sys.argv[1][len(a):] - del sys.argv[1] - split = dict() - with codecs.open(split_file, 'r', 'utf-8') as fh: - for line in fh: # line in unicode - words = line.strip().split() - if len(words) >= 2: - split[words[0]] = words[1:] - continue - a = '--ig=' - if sys.argv[1].startswith(a): - ignore_file = sys.argv[1][len(a):] - del sys.argv[1] - with codecs.open(ignore_file, 'r', 'utf-8') as fh: - for line in fh: # line in unicode - line = line.strip() - if len(line) > 0: - ignore_words.add(line) - continue - a = '--char=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):].lower() - del sys.argv[1] - tochar = (b == 'true') or (b != '0') - continue - a = '--v=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):].lower() - del sys.argv[1] - verbose=0 - try: - verbose=int(b) - except: - if b == 'true' or b != '0': - verbose = 1 - continue - a = '--padding-symbol=' - if sys.argv[1].startswith(a): - b = sys.argv[1][len(a):].lower() - del sys.argv[1] - if b == 'space': - padding_symbol= ' ' - elif b == 'underline': - padding_symbol= '_' - continue - if True or sys.argv[1].startswith('-'): - #ignore invalid switch - del sys.argv[1] - continue + if len(sys.argv) == 1: + usage() + sys.exit(0) + calculator = Calculator() + cluster_file = '' + ignore_words = set() + tochar = False + verbose = 1 + padding_symbol = ' ' + case_sensitive = False + max_words_per_line = sys.maxsize + split = None + while len(sys.argv) > 3: + a = '--maxw=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):] + del sys.argv[1] + max_words_per_line = int(b) + continue + a = '--rt=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + remove_tag = (b == 'true') or (b != '0') + continue + a = '--cs=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + case_sensitive = (b == 'true') or (b != '0') + continue + a = '--cluster=' + if sys.argv[1].startswith(a): + cluster_file = sys.argv[1][len(a):] + del sys.argv[1] + continue + a = '--splitfile=' + if sys.argv[1].startswith(a): + split_file = sys.argv[1][len(a):] + del sys.argv[1] + split = dict() + with codecs.open(split_file, 'r', 'utf-8') as fh: + for line in fh: # line in unicode + words = line.strip().split() + if len(words) >= 2: + split[words[0]] = words[1:] + continue + a = '--ig=' + if sys.argv[1].startswith(a): + ignore_file = sys.argv[1][len(a):] + del sys.argv[1] + with codecs.open(ignore_file, 'r', 'utf-8') as fh: + for line in fh: # line in unicode + line = line.strip() + if len(line) > 0: + ignore_words.add(line) + continue + a = '--char=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + tochar = (b == 'true') or (b != '0') + continue + a = '--v=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + verbose = 0 + try: + verbose = int(b) + except: + if b == 'true' or b != '0': + verbose = 1 + continue + a = '--padding-symbol=' + if sys.argv[1].startswith(a): + b = sys.argv[1][len(a):].lower() + del sys.argv[1] + if b == 'space': + padding_symbol = ' ' + elif b == 'underline': + padding_symbol = '_' + continue + if True or sys.argv[1].startswith('-'): + #ignore invalid switch + del sys.argv[1] + continue - if not case_sensitive: - ig=set([w.upper() for w in ignore_words]) - ignore_words = ig + if not case_sensitive: + ig = set([w.upper() for w in ignore_words]) + ignore_words = ig - default_clusters = {} - default_words = {} + default_clusters = {} + default_words = {} - ref_file = sys.argv[1] - hyp_file = sys.argv[2] - rec_set = {} - if split and not case_sensitive: - newsplit = dict() - for w in split: - words = split[w] - for i in range(len(words)): - words[i] = words[i].upper() - newsplit[w.upper()] = words - split = newsplit + ref_file = sys.argv[1] + hyp_file = sys.argv[2] + rec_set = {} + if split and not case_sensitive: + newsplit = dict() + for w in split: + words = split[w] + for i in range(len(words)): + words[i] = words[i].upper() + newsplit[w.upper()] = words + split = newsplit - with codecs.open(hyp_file, 'r', 'utf-8') as fh: - for line in fh: + with codecs.open(hyp_file, 'r', 'utf-8') as fh: + for line in fh: + if tochar: + array = characterize(line) + else: + array = line.strip().split() + if len(array) == 0: continue + fid = array[0] + rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, + split) + + # compute error rate on the interaction of reference file and hyp file + for line in open(ref_file, 'r', encoding='utf-8'): if tochar: array = characterize(line) else: - array = line.strip().split() - if len(array)==0: continue + array = line.rstrip('\n').split() + if len(array) == 0: continue fid = array[0] - rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, split) + if fid not in rec_set: + continue + lab = normalize(array[1:], ignore_words, case_sensitive, split) + rec = rec_set[fid] + if verbose: + print('\nutt: %s' % fid) - # compute error rate on the interaction of reference file and hyp file - for line in open(ref_file, 'r', encoding='utf-8') : - if tochar: - array = characterize(line) - else: - array = line.rstrip('\n').split() - if len(array)==0: continue - fid = array[0] - if fid not in rec_set: - continue - lab = normalize(array[1:], ignore_words, case_sensitive, split) - rec = rec_set[fid] - if verbose: - print('\nutt: %s' % fid) + for word in rec + lab: + if word not in default_words: + default_cluster_name = default_cluster(word) + if default_cluster_name not in default_clusters: + default_clusters[default_cluster_name] = {} + if word not in default_clusters[default_cluster_name]: + default_clusters[default_cluster_name][word] = 1 + default_words[word] = default_cluster_name - for word in rec + lab : - if word not in default_words : - default_cluster_name = default_cluster(word) - if default_cluster_name not in default_clusters : - default_clusters[default_cluster_name] = {} - if word not in default_clusters[default_cluster_name] : - default_clusters[default_cluster_name][word] = 1 - default_words[word] = default_cluster_name + result = calculator.calculate(lab, rec) + if verbose: + if result['all'] != 0: + wer = float(result['ins'] + result['sub'] + result[ + 'del']) * 100.0 / result['all'] + else: + wer = 0.0 + print('WER: %4.2f %%' % wer, end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], + result['ins'])) + space = {} + space['lab'] = [] + space['rec'] = [] + for idx in range(len(result['lab'])): + len_lab = width(result['lab'][idx]) + len_rec = width(result['rec'][idx]) + length = max(len_lab, len_rec) + space['lab'].append(length - len_lab) + space['rec'].append(length - len_rec) + upper_lab = len(result['lab']) + upper_rec = len(result['rec']) + lab1, rec1 = 0, 0 + while lab1 < upper_lab or rec1 < upper_rec: + if verbose > 1: + print('lab(%s):' % fid.encode('utf-8'), end=' ') + else: + print('lab:', end=' ') + lab2 = min(upper_lab, lab1 + max_words_per_line) + for idx in range(lab1, lab2): + token = result['lab'][idx] + print('{token}'.format(token=token), end='') + for n in range(space['lab'][idx]): + print(padding_symbol, end='') + print(' ', end='') + print() + if verbose > 1: + print('rec(%s):' % fid.encode('utf-8'), end=' ') + else: + print('rec:', end=' ') + rec2 = min(upper_rec, rec1 + max_words_per_line) + for idx in range(rec1, rec2): + token = result['rec'][idx] + print('{token}'.format(token=token), end='') + for n in range(space['rec'][idx]): + print(padding_symbol, end='') + print(' ', end='') + print('\n', end='\n') + lab1 = lab2 + rec1 = rec2 - result = calculator.calculate(lab, rec) if verbose: - if result['all'] != 0 : - wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] - else : - wer = 0.0 - print('WER: %4.2f %%' % wer, end = ' ') - print('N=%d C=%d S=%d D=%d I=%d' % - (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) - space = {} - space['lab'] = [] - space['rec'] = [] - for idx in range(len(result['lab'])) : - len_lab = width(result['lab'][idx]) - len_rec = width(result['rec'][idx]) - length = max(len_lab, len_rec) - space['lab'].append(length-len_lab) - space['rec'].append(length-len_rec) - upper_lab = len(result['lab']) - upper_rec = len(result['rec']) - lab1, rec1 = 0, 0 - while lab1 < upper_lab or rec1 < upper_rec: - if verbose > 1: - print('lab(%s):' % fid.encode('utf-8'), end = ' ') - else: - print('lab:', end = ' ') - lab2 = min(upper_lab, lab1 + max_words_per_line) - for idx in range(lab1, lab2): - token = result['lab'][idx] - print('{token}'.format(token = token), end = '') - for n in range(space['lab'][idx]) : - print(padding_symbol, end = '') - print(' ',end='') - print() - if verbose > 1: - print('rec(%s):' % fid.encode('utf-8'), end = ' ') - else: - print('rec:', end = ' ') - rec2 = min(upper_rec, rec1 + max_words_per_line) - for idx in range(rec1, rec2): - token = result['rec'][idx] - print('{token}'.format(token = token), end = '') - for n in range(space['rec'][idx]) : - print(padding_symbol, end = '') - print(' ',end='') - print('\n', end='\n') - lab1 = lab2 - rec1 = rec2 - - if verbose: - print('===========================================================================') - print() - - result = calculator.overall() - if result['all'] != 0 : - wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] - else : - wer = 0.0 - print('Overall -> %4.2f %%' % wer, end = ' ') - print('N=%d C=%d S=%d D=%d I=%d' % - (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) - if not verbose: - print() + print( + '===========================================================================' + ) + print() - if verbose: - for cluster_id in default_clusters : - result = calculator.cluster([ k for k in default_clusters[cluster_id] ]) - if result['all'] != 0 : - wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] - else : + result = calculator.overall() + if result['all'] != 0: + wer = float(result['ins'] + result['sub'] + result[ + 'del']) * 100.0 / result['all'] + else: wer = 0.0 - print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ') - print('N=%d C=%d S=%d D=%d I=%d' % - (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) - if len(cluster_file) > 0 : # compute separated WERs for word clusters - cluster_id = '' - cluster = [] - for line in open(cluster_file, 'r', encoding='utf-8') : - for token in line.decode('utf-8').rstrip('\n').split() : - # end of cluster reached, like - if token[0:2] == '' and \ - token.lstrip('') == cluster_id : - result = calculator.cluster(cluster) - if result['all'] != 0 : - wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] - else : - wer = 0.0 - print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ') - print('N=%d C=%d S=%d D=%d I=%d' % - (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) - cluster_id = '' - cluster = [] - # begin of cluster reached, like - elif token[0] == '<' and token[len(token)-1] == '>' and \ - cluster_id == '' : - cluster_id = token.lstrip('<').rstrip('>') - cluster = [] - # general terms, like WEATHER / CAR / ... - else : - cluster.append(token) - print() - print('===========================================================================') \ No newline at end of file + print('Overall -> %4.2f %%' % wer, end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], + result['ins'])) + if not verbose: + print() + + if verbose: + for cluster_id in default_clusters: + result = calculator.cluster( + [k for k in default_clusters[cluster_id]]) + if result['all'] != 0: + wer = float(result['ins'] + result['sub'] + result[ + 'del']) * 100.0 / result['all'] + else: + wer = 0.0 + print('%s -> %4.2f %%' % (cluster_id, wer), end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], result['del'], + result['ins'])) + if len(cluster_file) > 0: # compute separated WERs for word clusters + cluster_id = '' + cluster = [] + for line in open(cluster_file, 'r', encoding='utf-8'): + for token in line.decode('utf-8').rstrip('\n').split(): + # end of cluster reached, like + if token[0:2] == '' and \ + token.lstrip('') == cluster_id : + result = calculator.cluster(cluster) + if result['all'] != 0: + wer = float(result['ins'] + result['sub'] + result[ + 'del']) * 100.0 / result['all'] + else: + wer = 0.0 + print('%s -> %4.2f %%' % (cluster_id, wer), end=' ') + print('N=%d C=%d S=%d D=%d I=%d' % + (result['all'], result['cor'], result['sub'], + result['del'], result['ins'])) + cluster_id = '' + cluster = [] + # begin of cluster reached, like + elif token[0] == '<' and token[len(token)-1] == '>' and \ + cluster_id == '' : + cluster_id = token.lstrip('<').rstrip('>') + cluster = [] + # general terms, like WEATHER / CAR / ... + else: + cluster.append(token) + print() + print( + '===========================================================================' + ) diff --git a/utils/format_rsl.py b/utils/format_rsl.py index d5bc0017..1a714253 100644 --- a/utils/format_rsl.py +++ b/utils/format_rsl.py @@ -1,11 +1,21 @@ -import os +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import argparse import jsonlines -def trans_hyp(origin_hyp, - trans_hyp = None, - trans_hyp_sclite = None): +def trans_hyp(origin_hyp, trans_hyp=None, trans_hyp_sclite=None): """ Args: origin_hyp: The input json file which contains the model output @@ -17,19 +27,18 @@ def trans_hyp(origin_hyp, with open(origin_hyp, "r+", encoding="utf8") as f: for item in jsonlines.Reader(f): input_dict[item["utt"]] = item["hyps"][0] - if trans_hyp is not None: + if trans_hyp is not None: with open(trans_hyp, "w+", encoding="utf8") as f: for key in input_dict.keys(): f.write(key + " " + input_dict[key] + "\n") - if trans_hyp_sclite is not None: + if trans_hyp_sclite is not None: with open(trans_hyp_sclite, "w+") as f: for key in input_dict.keys(): - line = input_dict[key] + "(" + key + ".wav" +")" + "\n" + line = input_dict[key] + "(" + key + ".wav" + ")" + "\n" f.write(line) -def trans_ref(origin_ref, - trans_ref = None, - trans_ref_sclite = None): + +def trans_ref(origin_ref, trans_ref=None, trans_ref_sclite=None): """ Args: origin_hyp: The input json file which contains the model output @@ -49,42 +58,48 @@ def trans_ref(origin_ref, if trans_ref_sclite is not None: with open(trans_ref_sclite, "w") as f: for key in input_dict.keys(): - line = input_dict[key] + "(" + key + ".wav" +")" + "\n" + line = input_dict[key] + "(" + key + ".wav" + ")" + "\n" f.write(line) - if __name__ == "__main__": - parser = argparse.ArgumentParser(prog='format hyp file for compute CER/WER', add_help=True) + parser = argparse.ArgumentParser( + prog='format hyp file for compute CER/WER', add_help=True) parser.add_argument( - '--origin_hyp', - type=str, - default = None, - help='origin hyp file') + '--origin_hyp', type=str, default=None, help='origin hyp file') parser.add_argument( - '--trans_hyp', type=str, default = None, help='hyp file for caculating CER/WER') + '--trans_hyp', + type=str, + default=None, + help='hyp file for caculating CER/WER') parser.add_argument( - '--trans_hyp_sclite', type=str, default = None, help='hyp file for caculating CER/WER by sclite') + '--trans_hyp_sclite', + type=str, + default=None, + help='hyp file for caculating CER/WER by sclite') parser.add_argument( - '--origin_ref', - type=str, - default = None, - help='origin ref file') + '--origin_ref', type=str, default=None, help='origin ref file') parser.add_argument( - '--trans_ref', type=str, default = None, help='ref file for caculating CER/WER') + '--trans_ref', + type=str, + default=None, + help='ref file for caculating CER/WER') parser.add_argument( - '--trans_ref_sclite', type=str, default = None, help='ref file for caculating CER/WER by sclite') + '--trans_ref_sclite', + type=str, + default=None, + help='ref file for caculating CER/WER by sclite') parser_args = parser.parse_args() if parser_args.origin_hyp is not None: trans_hyp( - origin_hyp = parser_args.origin_hyp, - trans_hyp = parser_args.trans_hyp, - trans_hyp_sclite = parser_args.trans_hyp_sclite, ) + origin_hyp=parser_args.origin_hyp, + trans_hyp=parser_args.trans_hyp, + trans_hyp_sclite=parser_args.trans_hyp_sclite, ) if parser_args.origin_ref is not None: trans_ref( - origin_ref = parser_args.origin_ref, - trans_ref = parser_args.trans_ref, - trans_ref_sclite = parser_args.trans_ref_sclite, ) + origin_ref=parser_args.origin_ref, + trans_ref=parser_args.trans_ref, + trans_ref_sclite=parser_args.trans_ref_sclite, ) diff --git a/utils/fst/prepare_dict.py b/utils/fst/prepare_dict.py index 301d72fb..e000856e 100755 --- a/utils/fst/prepare_dict.py +++ b/utils/fst/prepare_dict.py @@ -35,7 +35,7 @@ def main(args): # used to filter polyphone and invalid word lexicon_table = set() in_n = 0 # in lexicon word count - out_n = 0 # out lexicon word cout + out_n = 0 # out lexicon word cout with open(args.in_lexicon, 'r') as fin, \ open(args.out_lexicon, 'w') as fout: for line in fin: @@ -82,7 +82,10 @@ def main(args): lexicon_table.add(word) out_n += 1 - print(f"Filter lexicon by unit table: filter out {in_n - out_n}, {out_n}/{in_n}") + print( + f"Filter lexicon by unit table: filter out {in_n - out_n}, {out_n}/{in_n}" + ) + if __name__ == '__main__': parser = argparse.ArgumentParser( -- GitLab