Commit c7d9b115 authored by: H Hui Zhang

format

Parent caf72258
......@@ -12,6 +12,8 @@ exclude =
.git,
# python cache
__pycache__,
# third party
utils/compute-wer.py,
third_party/,
# Provide a comma-separate list of glob patterns to include for checks.
filename =
......
......@@ -40,6 +40,7 @@ from paddlespeech.s2t.utils.utility import UpdateConfig
__all__ = ['ASRExecutor']
@cli_register(
name='paddlespeech.asr', description='Speech to text infer command.')
class ASRExecutor(BaseExecutor):
......@@ -148,7 +149,7 @@ class ASRExecutor(BaseExecutor):
os.path.dirname(os.path.abspath(self.cfg_path)))
logger.info(self.cfg_path)
logger.info(self.ckpt_path)
#Init body.
self.config = CfgNode(new_allowed=True)
self.config.merge_from_file(self.cfg_path)
......@@ -278,7 +279,8 @@ class ASRExecutor(BaseExecutor):
self._outputs["result"] = result_transcripts[0]
elif "conformer" in model_type or "transformer" in model_type:
logger.info(f"we will use the transformer like model : {model_type}")
logger.info(
f"we will use the transformer like model : {model_type}")
try:
result_transcripts = self.model.decode(
audio,
......
......@@ -279,7 +279,7 @@ class U2BaseModel(ASRInterface, nn.Layer):
# TODO(Hui Zhang): if end_flag.sum() == running_size:
if end_flag.cast(paddle.int64).sum() == running_size:
break
# 2.1 Forward decoder step
hyps_mask = subsequent_mask(i).unsqueeze(0).repeat(
running_size, 1, 1).to(device) # (B*N, i, i)
......
......@@ -180,7 +180,7 @@ class CTCDecoder(CTCDecoderBase):
# init once
if self._ext_scorer is not None:
return
if language_model_path != '':
logger.info("begin to initialize the external scorer "
"for decoding")
......
......@@ -47,4 +47,4 @@ paddlespeech_server start --config_file conf/ws_conformer_application.yaml
```
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input input_16k.wav
```
\ No newline at end of file
```
......@@ -48,4 +48,4 @@ paddlespeech_server start --config_file conf/ws_conformer_application.yaml
```
paddlespeech_client asr_online --server_ip 127.0.0.1 --port 8090 --input zh.wav
```
\ No newline at end of file
```
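For a quick programmatic check, the same request can also be issued from Python rather than the shell. The snippet below is only a sketch: the import path and keyword arguments are assumptions inferred from the `ASRClientExecutor` class touched later in this commit and from the CLI flags above, so verify them against the class definition before relying on them.

```
# Hypothetical programmatic call; the argument names mirror the CLI flags
# (--server_ip, --port, --input) and are assumptions, not a verified API.
from paddlespeech.server.bin.paddlespeech_client import ASRClientExecutor

asr_client = ASRClientExecutor()
res = asr_client(input="zh.wav", server_ip="127.0.0.1", port=8090)
print(res)
```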
......@@ -305,6 +305,7 @@ class ASRClientExecutor(BaseExecutor):
return res['asr_results']
@cli_client_register(
name='paddlespeech_client.cls', description='visit cls service')
class CLSClientExecutor(BaseExecutor):
......
......@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
import paddle
from paddlespeech.cli.log import logger
from paddlespeech.s2t.utils.utility import log_add
......
......@@ -36,7 +36,7 @@ class ASRAudioHandler:
x_len = len(samples)
chunk_size = 85 * 16 #80ms, sample_rate = 16kHz
if x_len % chunk_size!= 0:
if x_len % chunk_size != 0:
padding_len_x = chunk_size - x_len % chunk_size
else:
padding_len_x = 0
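# Side note (not part of this diff): the padding arithmetic above in isolation.
# The stream is split into fixed chunks of 85 * 16 samples at 16 kHz and the tail
# is zero-padded so the last chunk is full; the length below is an illustrative value.
chunk_size = 85 * 16
x_len = 50000
padding_len_x = chunk_size - x_len % chunk_size if x_len % chunk_size != 0 else 0
assert (x_len + padding_len_x) % chunk_size == 0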
......@@ -92,7 +92,7 @@ class ASRAudioHandler:
separators=(',', ': '))
await ws.send(audio_info)
msg = await ws.recv()
# decode the bytes to str
msg = json.loads(msg)
logging.info("final receive msg={}".format(msg))
......
......@@ -52,7 +52,7 @@ def evaluate(args):
# acoustic model
am_name = args.am[:args.am.rindex('_')]
am_dataset = args.am[args.am.rindex('_') + 1:]
am_inference = get_am_inference(
am=args.am,
am_config=am_config,
......
......@@ -20,11 +20,11 @@ A few sklearn functions are modified in this script as per requirement.
import argparse
import copy
import warnings
from distutils.util import strtobool
import numpy as np
import scipy
import sklearn
from distutils.util import strtobool
from scipy import linalg
from scipy import sparse
from scipy.sparse.csgraph import connected_components
......
......@@ -2,6 +2,7 @@
import argparse
from collections import Counter
def main(args):
counter = Counter()
with open(args.text, 'r') as fin, open(args.lexicon, 'w') as fout:
......@@ -12,7 +13,7 @@ def main(args):
words = text.split()
else:
words = line.split()
counter.update(words)
for word in counter:
......@@ -20,21 +21,16 @@ def main(args):
fout.write(f"{word}\t{val}\n")
fout.flush()
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='text(line:utt1 中国 人) to lexicon(line:中国 中 国).')
parser.add_argument(
'--has_key',
default=True,
help='text path, with utt or not')
'--has_key', default=True, help='text path, with utt or not')
parser.add_argument(
'--text',
required=True,
help='text path. line: utt1 中国 人 or 中国 人')
'--text', required=True, help='text path. line: utt1 中国 人 or 中国 人')
parser.add_argument(
'--lexicon',
required=True,
help='lexicon path. line:中国 中 国')
'--lexicon', required=True, help='lexicon path. line:中国 中 国')
args = parser.parse_args()
print(args)
......
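# Illustrative sketch (not part of this diff) of the text-to-lexicon mapping documented
# by the argparse description above: each word becomes a line "word<TAB>characters".
# How `val` is actually built is elided in the hunk, so the space join is an assumption.
word = u"中国"
print(f"{word}\t{' '.join(word)}")  # -> "中国\t中 国"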
#!/usr/bin/env python3
# modify from https://sites.google.com/site/homepageoffuyanwei/Home/remarksandexcellentdiscussion/page-2
class Word:
def __init__(self,text = '',freq = 0):
self.text = text
self.freq = freq
self.length = len(text)
class Word:
def __init__(self, text='', freq=0):
self.text = text
self.freq = freq
self.length = len(text)
class Chunk:
def __init__(self,w1,w2 = None,w3 = None):
self.words = []
self.words.append(w1)
if w2:
self.words.append(w2)
if w3:
self.words.append(w3)
def __init__(self, w1, w2=None, w3=None):
self.words = []
self.words.append(w1)
if w2:
self.words.append(w2)
if w3:
self.words.append(w3)
#计算chunk的总长度
def totalWordLength(self):
length = 0
for word in self.words:
length += len(word.text)
return length
def totalWordLength(self):
length = 0
for word in self.words:
length += len(word.text)
return length
#计算平均长度
def averageWordLength(self):
return float(self.totalWordLength()) / float(len(self.words))
def averageWordLength(self):
return float(self.totalWordLength()) / float(len(self.words))
#计算标准差
def standardDeviation(self):
average = self.averageWordLength()
sum = 0.0
for word in self.words:
tmp = (len(word.text) - average)
sum += float(tmp) * float(tmp)
return sum
def standardDeviation(self):
average = self.averageWordLength()
sum = 0.0
for word in self.words:
tmp = (len(word.text) - average)
sum += float(tmp) * float(tmp)
return sum
#自由语素度
def wordFrequency(self):
sum = 0
for word in self.words:
sum += word.freq
return sum
class ComplexCompare:
sum = 0
for word in self.words:
sum += word.freq
return sum
class ComplexCompare:
def takeHightest(self, chunks, comparator):
i = 1
for j in range(1, len(chunks)):
rlt = comparator(chunks[j], chunks[0])
if rlt > 0:
i = 0
if rlt >= 0:
chunks[i], chunks[j] = chunks[j], chunks[i]
i += 1
i = 1
for j in range(1, len(chunks)):
rlt = comparator(chunks[j], chunks[0])
if rlt > 0:
i = 0
if rlt >= 0:
chunks[i], chunks[j] = chunks[j], chunks[i]
i += 1
return chunks[0:i]
#以下四个函数是mmseg算法的四种过滤原则,核心算法
def mmFilter(self, chunks):
def comparator(a,b):
return a.totalWordLength() - b.totalWordLength()
return self.takeHightest(chunks, comparator)
def lawlFilter(self,chunks):
def comparator(a,b):
return a.averageWordLength() - b.averageWordLength()
return self.takeHightest(chunks,comparator)
def svmlFilter(self,chunks):
def comparator(a,b):
return b.standardDeviation() - a.standardDeviation()
return self.takeHightest(chunks, comparator)
def logFreqFilter(self,chunks):
def comparator(a,b):
return a.wordFrequency() - b.wordFrequency()
return self.takeHightest(chunks, comparator)
def comparator(a, b):
return a.totalWordLength() - b.totalWordLength()
return self.takeHightest(chunks, comparator)
def lawlFilter(self, chunks):
def comparator(a, b):
return a.averageWordLength() - b.averageWordLength()
return self.takeHightest(chunks, comparator)
def svmlFilter(self, chunks):
def comparator(a, b):
return b.standardDeviation() - a.standardDeviation()
return self.takeHightest(chunks, comparator)
def logFreqFilter(self, chunks):
def comparator(a, b):
return a.wordFrequency() - b.wordFrequency()
return self.takeHightest(chunks, comparator)
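# Quick illustration (not part of this diff): the chunk metrics that the four mmseg
# filters above compare. The word frequencies here are made-up example values.
_chunk = Chunk(Word(u'研究', 20), Word(u'生命', 10), Word(u'来源', 5))
assert _chunk.totalWordLength() == 6       # 2 + 2 + 2 characters
assert _chunk.averageWordLength() == 2.0   # 6 / 3 words
assert _chunk.standardDeviation() == 0.0   # all words have the same length
assert _chunk.wordFrequency() == 35        # 20 + 10 + 5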
#加载词组字典和字符字典
dictWord = {}
maxWordLength = 0
def loadDictChars(filepath):
global maxWordLength
def loadDictChars(filepath):
global maxWordLength
fsock = open(filepath)
for line in fsock:
freq, word = line.split()
word = word.strip()
dictWord[word] = (len(word), int(freq))
maxWordLength = len(word) if maxWordLength < len(word) else maxWordLength
fsock.close()
def loadDictWords(filepath):
global maxWordLength
fsock = open(filepath)
for line in fsock.readlines():
dictWord[word] = (len(word), int(freq))
maxWordLength = len(word) if maxWordLength < len(
word) else maxWordLength
fsock.close()
def loadDictWords(filepath):
global maxWordLength
fsock = open(filepath)
for line in fsock.readlines():
word = line.strip()
dictWord[word] = (len(word), 0)
maxWordLength = len(word) if maxWordLength < len(word) else maxWordLength
fsock.close()
dictWord[word] = (len(word), 0)
maxWordLength = len(word) if maxWordLength < len(
word) else maxWordLength
fsock.close()
#判断该词word是否在字典dictWord中
def getDictWord(word):
result = dictWord.get(word)
if result:
return Word(word, result[1])
return None
def getDictWord(word):
result = dictWord.get(word)
if result:
return Word(word, result[1])
return None
#开始加载字典
def run():
from os.path import join, dirname
loadDictChars(join(dirname(__file__), 'data', 'chars.dic'))
loadDictWords(join(dirname(__file__), 'data', 'words.dic'))
class Analysis:
def __init__(self, text):
def run():
from os.path import join, dirname
loadDictChars(join(dirname(__file__), 'data', 'chars.dic'))
loadDictWords(join(dirname(__file__), 'data', 'words.dic'))
class Analysis:
def __init__(self, text):
self.text = text
self.cacheSize = 3
self.pos = 0
self.textLength = len(self.text)
self.cache = []
self.cacheIndex = 0
self.complexCompare = ComplexCompare()
self.cacheSize = 3
self.pos = 0
self.textLength = len(self.text)
self.cache = []
self.cacheIndex = 0
self.complexCompare = ComplexCompare()
#简单小技巧,用到个缓存,不知道具体有没有用处
for i in range(self.cacheSize):
for i in range(self.cacheSize):
self.cache.append([-1, Word()])
#控制字典只加载一次
if not dictWord:
run()
def __iter__(self):
while True:
token = self.getNextToken()
if token is None:
raise StopIteration
yield token
def getNextChar(self):
return self.text[self.pos]
def __iter__(self):
while True:
token = self.getNextToken()
if token == None:
raise StopIteration
yield token
def getNextChar(self):
return self.text[self.pos]
#判断该字符是否是中文字符(不包括中文标点)
def isChineseChar(self,charater):
return 0x4e00 <= ord(charater) < 0x9fa6
def isChineseChar(self, charater):
return 0x4e00 <= ord(charater) < 0x9fa6
#判断是否是ASCII码
def isASCIIChar(self, ch):
import string
if ch in string.whitespace:
return False
if ch in string.punctuation:
return False
def isASCIIChar(self, ch):
import string
if ch in string.whitespace:
return False
if ch in string.punctuation:
return False
return ch in string.printable
#得到下一个切割结果
def getNextToken(self):
while self.pos < self.textLength:
if self.isChineseChar(self.getNextChar()):
token = self.getChineseWords()
else :
token = self.getASCIIWords()+'/'
if len(token) > 0:
def getNextToken(self):
while self.pos < self.textLength:
if self.isChineseChar(self.getNextChar()):
token = self.getChineseWords()
else:
token = self.getASCIIWords() + '/'
if len(token) > 0:
return token
return None
#切割出非中文词
def getASCIIWords(self):
# Skip pre-word whitespaces and punctuations
#跳过中英文标点和空格
while self.pos < self.textLength:
ch = self.getNextChar()
if self.isASCIIChar(ch) or self.isChineseChar(ch):
break
self.pos += 1
while self.pos < self.textLength:
ch = self.getNextChar()
if self.isASCIIChar(ch) or self.isChineseChar(ch):
break
self.pos += 1
#得到英文单词的起始位置
start = self.pos
start = self.pos
#找出英文单词的结束位置
while self.pos < self.textLength:
ch = self.getNextChar()
if not self.isASCIIChar(ch):
break
self.pos += 1
end = self.pos
while self.pos < self.textLength:
ch = self.getNextChar()
if not self.isASCIIChar(ch):
break
self.pos += 1
end = self.pos
#Skip chinese word whitespaces and punctuations
#跳过中英文标点和空格
while self.pos < self.textLength:
ch = self.getNextChar()
if self.isASCIIChar(ch) or self.isChineseChar(ch):
break
self.pos += 1
while self.pos < self.textLength:
ch = self.getNextChar()
if self.isASCIIChar(ch) or self.isChineseChar(ch):
break
self.pos += 1
#返回英文单词
return self.text[start:end]
return self.text[start:end]
#切割出中文词,并且做处理,用上述4种方法
def getChineseWords(self):
chunks = self.createChunks()
if len(chunks) > 1:
chunks = self.complexCompare.mmFilter(chunks)
if len(chunks) > 1:
chunks = self.complexCompare.lawlFilter(chunks)
if len(chunks) > 1:
chunks = self.complexCompare.svmlFilter(chunks)
if len(chunks) > 1:
chunks = self.complexCompare.logFreqFilter(chunks)
if len(chunks) == 0 :
return ''
def getChineseWords(self):
chunks = self.createChunks()
if len(chunks) > 1:
chunks = self.complexCompare.mmFilter(chunks)
if len(chunks) > 1:
chunks = self.complexCompare.lawlFilter(chunks)
if len(chunks) > 1:
chunks = self.complexCompare.svmlFilter(chunks)
if len(chunks) > 1:
chunks = self.complexCompare.logFreqFilter(chunks)
if len(chunks) == 0:
return ''
#最后只有一种切割方法
word = chunks[0].words
token = ""
length = 0
for x in word:
if x.length != -1:
token += x.text + "/"
length += len(x.text)
self.pos += length
return token
word = chunks[0].words
token = ""
length = 0
for x in word:
if x.length != -1:
token += x.text + "/"
length += len(x.text)
self.pos += length
return token
#三重循环来枚举切割方法,这里也可以运用递归来实现
def createChunks(self):
chunks = []
originalPos = self.pos
words1 = self.getMatchChineseWords()
for word1 in words1:
self.pos += len(word1.text)
if self.pos < self.textLength:
words2 = self.getMatchChineseWords()
for word2 in words2:
self.pos += len(word2.text)
if self.pos < self.textLength:
words3 = self.getMatchChineseWords()
for word3 in words3:
def createChunks(self):
chunks = []
originalPos = self.pos
words1 = self.getMatchChineseWords()
for word1 in words1:
self.pos += len(word1.text)
if self.pos < self.textLength:
words2 = self.getMatchChineseWords()
for word2 in words2:
self.pos += len(word2.text)
if self.pos < self.textLength:
words3 = self.getMatchChineseWords()
for word3 in words3:
# print(word3.length, word3.text)
if word3.length == -1:
chunk = Chunk(word1,word2)
if word3.length == -1:
chunk = Chunk(word1, word2)
# print("Ture")
else :
chunk = Chunk(word1,word2,word3)
chunks.append(chunk)
elif self.pos == self.textLength:
chunks.append(Chunk(word1,word2))
self.pos -= len(word2.text)
elif self.pos == self.textLength:
chunks.append(Chunk(word1))
self.pos -= len(word1.text)
self.pos = originalPos
return chunks
else:
chunk = Chunk(word1, word2, word3)
chunks.append(chunk)
elif self.pos == self.textLength:
chunks.append(Chunk(word1, word2))
self.pos -= len(word2.text)
elif self.pos == self.textLength:
chunks.append(Chunk(word1))
self.pos -= len(word1.text)
self.pos = originalPos
return chunks
#运用正向最大匹配算法结合字典来切割中文文本
def getMatchChineseWords(self):
def getMatchChineseWords(self):
#use cache,check it
for i in range(self.cacheSize):
if self.cache[i][0] == self.pos:
return self.cache[i][1]
originalPos = self.pos
words = []
index = 0
while self.pos < self.textLength:
if index >= maxWordLength :
break
if not self.isChineseChar(self.getNextChar()):
break
self.pos += 1
index += 1
text = self.text[originalPos:self.pos]
word = getDictWord(text)
if word:
words.append(word)
self.pos = originalPos
for i in range(self.cacheSize):
if self.cache[i][0] == self.pos:
return self.cache[i][1]
originalPos = self.pos
words = []
index = 0
while self.pos < self.textLength:
if index >= maxWordLength:
break
if not self.isChineseChar(self.getNextChar()):
break
self.pos += 1
index += 1
text = self.text[originalPos:self.pos]
word = getDictWord(text)
if word:
words.append(word)
self.pos = originalPos
#没有词则放置个‘X’,将文本长度标记为-1
if not words:
word = Word()
word.length = -1
word.text = 'X'
words.append(word)
self.cache[self.cacheIndex] = (self.pos,words)
self.cacheIndex += 1
if self.cacheIndex >= self.cacheSize:
self.cacheIndex = 0
return words
if __name__=="__main__":
def cuttest(text):
if not words:
word = Word()
word.length = -1
word.text = 'X'
words.append(word)
self.cache[self.cacheIndex] = (self.pos, words)
self.cacheIndex += 1
if self.cacheIndex >= self.cacheSize:
self.cacheIndex = 0
return words
if __name__ == "__main__":
def cuttest(text):
#cut = Analysis(text)
tmp=""
tmp = ""
try:
for word in iter(Analysis(text)):
tmp += word
......@@ -310,71 +320,73 @@ if __name__=="__main__":
print("================================")
cuttest(u"研究生命来源")
cuttest(u"南京市长江大桥欢迎您")
cuttest(u"请把手抬高一点儿")
cuttest(u"长春市长春节致词。")
cuttest(u"长春市长春药店。")
cuttest(u"我的和服务必在明天做好。")
cuttest(u"我发现有很多人喜欢他。")
cuttest(u"我喜欢看电视剧大长今。")
cuttest(u"半夜给拎起来陪看欧洲杯糊着两眼半晌没搞明白谁和谁踢。")
cuttest(u"李智伟高高兴兴以及王晓薇出去玩,后来智伟和晓薇又单独去玩了。")
cuttest(u"一次性交出去很多钱。 ")
cuttest(u"这是一个伸手不见五指的黑夜。我叫孙悟空,我爱北京,我爱Python和C++。")
cuttest(u"我不喜欢日本和服。")
cuttest(u"雷猴回归人间。")
cuttest(u"工信处女干事每月经过下属科室都要亲口交代24口交换机等技术性器件的安装工作")
cuttest(u"我需要廉租房")
cuttest(u"永和服装饰品有限公司")
cuttest(u"我爱北京天安门")
cuttest(u"abc")
cuttest(u"隐马尔可夫")
cuttest(u"雷猴是个好网站")
cuttest(u"“Microsoft”一词由“MICROcomputer(微型计算机)”和“SOFTware(软件)”两部分组成")
cuttest(u"草泥马和欺实马是今年的流行词汇")
cuttest(u"伊藤洋华堂总府店")
cuttest(u"中国科学院计算技术研究所")
cuttest(u"罗密欧与朱丽叶")
cuttest(u"我购买了道具和服装")
cuttest(u"PS: 我觉得开源有一个好处,就是能够敦促自己不断改进,避免敞帚自珍")
cuttest(u"湖北省石首市")
cuttest(u"总经理完成了这件事情")
cuttest(u"电脑修好了")
cuttest(u"做好了这件事情就一了百了了")
cuttest(u"人们审美的观点是不同的")
cuttest(u"我们买了一个美的空调")
cuttest(u"线程初始化时我们要注意")
cuttest(u"一个分子是由好多原子组织成的")
cuttest(u"祝你马到功成")
cuttest(u"他掉进了无底洞里")
cuttest(u"中国的首都是北京")
cuttest(u"孙君意")
cuttest(u"外交部发言人马朝旭")
cuttest(u"领导人会议和第四届东亚峰会")
cuttest(u"在过去的这五年")
cuttest(u"还需要很长的路要走")
cuttest(u"60周年首都阅兵")
cuttest(u"你好人们审美的观点是不同的")
cuttest(u"买水果然后来世博园")
cuttest(u"买水果然后去世博园")
cuttest(u"但是后来我才知道你是对的")
cuttest(u"存在即合理")
cuttest(u"的的的的的在的的的的就以和和和")
cuttest(u"I love你,不以为耻,反以为rong")
cuttest(u" ")
cuttest(u"")
cuttest(u"hello你好人们审美的观点是不同的")
cuttest(u"很好但主要是基于网页形式")
cuttest(u"hello你好人们审美的观点是不同的")
cuttest(u"为什么我不能拥有想要的生活")
cuttest(u"后来我才")
cuttest(u"此次来中国是为了")
cuttest(u"使用了它就可以解决一些问题")
cuttest(u",使用了它就可以解决一些问题")
cuttest(u"其实使用了它就可以解决一些问题")
cuttest(u"好人使用了它就可以解决一些问题")
cuttest(u"是因为和国家")
cuttest(u"老年搜索还支持")
cuttest(u"干脆就把那部蒙人的闲法给废了拉倒!RT @laoshipukong : 27日,全国人大常委会第三次审议侵权责任法草案,删除了有关医疗损害责任“举证倒置”的规定。在医患纠纷中本已处于弱势地位的消费者由此将陷入万劫不复的境地。 ")
cuttest(u"南京市长江大桥欢迎您")
cuttest(u"请把手抬高一点儿")
cuttest(u"长春市长春节致词。")
cuttest(u"长春市长春药店。")
cuttest(u"我的和服务必在明天做好。")
cuttest(u"我发现有很多人喜欢他。")
cuttest(u"我喜欢看电视剧大长今。")
cuttest(u"半夜给拎起来陪看欧洲杯糊着两眼半晌没搞明白谁和谁踢。")
cuttest(u"李智伟高高兴兴以及王晓薇出去玩,后来智伟和晓薇又单独去玩了。")
cuttest(u"一次性交出去很多钱。 ")
cuttest(u"这是一个伸手不见五指的黑夜。我叫孙悟空,我爱北京,我爱Python和C++。")
cuttest(u"我不喜欢日本和服。")
cuttest(u"雷猴回归人间。")
cuttest(u"工信处女干事每月经过下属科室都要亲口交代24口交换机等技术性器件的安装工作")
cuttest(u"我需要廉租房")
cuttest(u"永和服装饰品有限公司")
cuttest(u"我爱北京天安门")
cuttest(u"abc")
cuttest(u"隐马尔可夫")
cuttest(u"雷猴是个好网站")
cuttest(u"“Microsoft”一词由“MICROcomputer(微型计算机)”和“SOFTware(软件)”两部分组成")
cuttest(u"草泥马和欺实马是今年的流行词汇")
cuttest(u"伊藤洋华堂总府店")
cuttest(u"中国科学院计算技术研究所")
cuttest(u"罗密欧与朱丽叶")
cuttest(u"我购买了道具和服装")
cuttest(u"PS: 我觉得开源有一个好处,就是能够敦促自己不断改进,避免敞帚自珍")
cuttest(u"湖北省石首市")
cuttest(u"总经理完成了这件事情")
cuttest(u"电脑修好了")
cuttest(u"做好了这件事情就一了百了了")
cuttest(u"人们审美的观点是不同的")
cuttest(u"我们买了一个美的空调")
cuttest(u"线程初始化时我们要注意")
cuttest(u"一个分子是由好多原子组织成的")
cuttest(u"祝你马到功成")
cuttest(u"他掉进了无底洞里")
cuttest(u"中国的首都是北京")
cuttest(u"孙君意")
cuttest(u"外交部发言人马朝旭")
cuttest(u"领导人会议和第四届东亚峰会")
cuttest(u"在过去的这五年")
cuttest(u"还需要很长的路要走")
cuttest(u"60周年首都阅兵")
cuttest(u"你好人们审美的观点是不同的")
cuttest(u"买水果然后来世博园")
cuttest(u"买水果然后去世博园")
cuttest(u"但是后来我才知道你是对的")
cuttest(u"存在即合理")
cuttest(u"的的的的的在的的的的就以和和和")
cuttest(u"I love你,不以为耻,反以为rong")
cuttest(u" ")
cuttest(u"")
cuttest(u"hello你好人们审美的观点是不同的")
cuttest(u"很好但主要是基于网页形式")
cuttest(u"hello你好人们审美的观点是不同的")
cuttest(u"为什么我不能拥有想要的生活")
cuttest(u"后来我才")
cuttest(u"此次来中国是为了")
cuttest(u"使用了它就可以解决一些问题")
cuttest(u",使用了它就可以解决一些问题")
cuttest(u"其实使用了它就可以解决一些问题")
cuttest(u"好人使用了它就可以解决一些问题")
cuttest(u"是因为和国家")
cuttest(u"老年搜索还支持")
cuttest(
u"干脆就把那部蒙人的闲法给废了拉倒!RT @laoshipukong : 27日,全国人大常委会第三次审议侵权责任法草案,删除了有关医疗损害责任“举证倒置”的规定。在医患纠纷中本已处于弱势地位的消费者由此将陷入万劫不复的境地。 "
)
cuttest("2022年12月30日是星期几?")
cuttest("二零二二年十二月三十日是星期几?")
\ No newline at end of file
cuttest("二零二二年十二月三十日是星期几?")
......@@ -183,4 +183,4 @@ data/
├── lexiconp_disambig.txt
├── lexiconp.txt
└── units.list
```
\ No newline at end of file
```
......@@ -26,9 +26,9 @@ import argparse
import os
import re
import subprocess
from distutils.util import strtobool
import numpy as np
from distutils.util import strtobool
FILE_IDS = re.compile(r"(?<=Speaker Diarization for).+(?=\*\*\*)")
SCORED_SPEAKER_TIME = re.compile(r"(?<=SCORED SPEAKER TIME =)[\d.]+")
......
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# CopyRight WeNet Apache-2.0 License
import re, sys, unicodedata
import codecs
import re
import sys
import unicodedata
remove_tag = True
spacelist= [' ', '\t', '\r', '\n']
puncts = ['!', ',', '?',
'、', '。', '!', ',', ';', '?',
':', '「', '」', '︰', '『', '』', '《', '》']
spacelist = [' ', '\t', '\r', '\n']
puncts = [
'!', ',', '?', '、', '。', '!', ',', ';', '?', ':', '「', '」', '︰', '『', '』',
'《', '》'
]
def characterize(string):
res = []
i = 0
while i < len(string):
char = string[i]
if char in puncts:
i += 1
continue
cat1 = unicodedata.category(char)
#https://unicodebook.readthedocs.io/unicode.html#unicode-categories
if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned
i += 1
continue
if cat1 == 'Lo': # letter-other
res.append(char)
i += 1
else:
# some input looks like: <unk><noise>, we want to separate it to two words.
sep = ' '
if char == '<': sep = '>'
j = i + 1
while j < len(string):
c = string[j]
if ord(c) >= 128 or (c in spacelist) or (c == sep):
break
j += 1
if j < len(string) and string[j] == '>':
j += 1
res.append(string[i:j])
i = j
return res
def characterize(string) :
res = []
i = 0
while i < len(string):
char = string[i]
if char in puncts:
i += 1
continue
cat1 = unicodedata.category(char)
#https://unicodebook.readthedocs.io/unicode.html#unicode-categories
if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist: # space or not assigned
i += 1
continue
if cat1 == 'Lo': # letter-other
res.append(char)
i += 1
else:
# some input looks like: <unk><noise>, we want to separate it to two words.
sep = ' '
if char == '<': sep = '>'
j = i+1
while j < len(string):
c = string[j]
if ord(c) >= 128 or (c in spacelist) or (c==sep):
break
j += 1
if j < len(string) and string[j] == '>':
j += 1
res.append(string[i:j])
i = j
return res
def stripoff_tags(x):
if not x: return ''
chars = []
i = 0; T=len(x)
while i < T:
if x[i] == '<':
while i < T and x[i] != '>':
i += 1
i += 1
else:
chars.append(x[i])
i += 1
return ''.join(chars)
if not x: return ''
chars = []
i = 0
T = len(x)
while i < T:
if x[i] == '<':
while i < T and x[i] != '>':
i += 1
i += 1
else:
chars.append(x[i])
i += 1
return ''.join(chars)
def normalize(sentence, ignore_words, cs, split=None):
......@@ -65,436 +70,485 @@ def normalize(sentence, ignore_words, cs, split=None):
for token in sentence:
x = token
if not cs:
x = x.upper()
x = x.upper()
if x in ignore_words:
continue
continue
if remove_tag:
x = stripoff_tags(x)
x = stripoff_tags(x)
if not x:
continue
continue
if split and x in split:
new_sentence += split[x]
new_sentence += split[x]
else:
new_sentence.append(x)
new_sentence.append(x)
return new_sentence
class Calculator :
def __init__(self) :
self.data = {}
self.space = []
self.cost = {}
self.cost['cor'] = 0
self.cost['sub'] = 1
self.cost['del'] = 1
self.cost['ins'] = 1
def calculate(self, lab, rec) :
# Initialization
lab.insert(0, '')
rec.insert(0, '')
while len(self.space) < len(lab) :
self.space.append([])
for row in self.space :
for element in row :
element['dist'] = 0
element['error'] = 'non'
while len(row) < len(rec) :
row.append({'dist' : 0, 'error' : 'non'})
for i in range(len(lab)) :
self.space[i][0]['dist'] = i
self.space[i][0]['error'] = 'del'
for j in range(len(rec)) :
self.space[0][j]['dist'] = j
self.space[0][j]['error'] = 'ins'
self.space[0][0]['error'] = 'non'
for token in lab :
if token not in self.data and len(token) > 0 :
self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0}
for token in rec :
if token not in self.data and len(token) > 0 :
self.data[token] = {'all' : 0, 'cor' : 0, 'sub' : 0, 'ins' : 0, 'del' : 0}
# Computing edit distance
for i, lab_token in enumerate(lab) :
for j, rec_token in enumerate(rec) :
if i == 0 or j == 0 :
continue
min_dist = sys.maxsize
min_error = 'none'
dist = self.space[i-1][j]['dist'] + self.cost['del']
error = 'del'
if dist < min_dist :
min_dist = dist
min_error = error
dist = self.space[i][j-1]['dist'] + self.cost['ins']
error = 'ins'
if dist < min_dist :
min_dist = dist
min_error = error
if lab_token == rec_token :
dist = self.space[i-1][j-1]['dist'] + self.cost['cor']
error = 'cor'
else :
dist = self.space[i-1][j-1]['dist'] + self.cost['sub']
error = 'sub'
if dist < min_dist :
min_dist = dist
min_error = error
self.space[i][j]['dist'] = min_dist
self.space[i][j]['error'] = min_error
# Tracing back
result = {'lab':[], 'rec':[], 'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
i = len(lab) - 1
j = len(rec) - 1
while True :
if self.space[i][j]['error'] == 'cor' : # correct
if len(lab[i]) > 0 :
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
result['all'] = result['all'] + 1
result['cor'] = result['cor'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, rec[j])
i = i - 1
j = j - 1
elif self.space[i][j]['error'] == 'sub' : # substitution
if len(lab[i]) > 0 :
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
result['all'] = result['all'] + 1
result['sub'] = result['sub'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, rec[j])
i = i - 1
j = j - 1
elif self.space[i][j]['error'] == 'del' : # deletion
if len(lab[i]) > 0 :
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
result['all'] = result['all'] + 1
result['del'] = result['del'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, "")
i = i - 1
elif self.space[i][j]['error'] == 'ins' : # insertion
if len(rec[j]) > 0 :
self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
result['ins'] = result['ins'] + 1
result['lab'].insert(0, "")
result['rec'].insert(0, rec[j])
j = j - 1
elif self.space[i][j]['error'] == 'non' : # starting point
break
else : # shouldn't reach here
print('this should not happen , i = {i} , j = {j} , error = {error}'.format(i = i, j = j, error = self.space[i][j]['error']))
return result
def overall(self) :
result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
for token in self.data :
result['all'] = result['all'] + self.data[token]['all']
result['cor'] = result['cor'] + self.data[token]['cor']
result['sub'] = result['sub'] + self.data[token]['sub']
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result
def cluster(self, data) :
result = {'all':0, 'cor':0, 'sub':0, 'ins':0, 'del':0}
for token in data :
if token in self.data :
result['all'] = result['all'] + self.data[token]['all']
result['cor'] = result['cor'] + self.data[token]['cor']
result['sub'] = result['sub'] + self.data[token]['sub']
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result
def keys(self) :
return list(self.data.keys())
class Calculator:
def __init__(self):
self.data = {}
self.space = []
self.cost = {}
self.cost['cor'] = 0
self.cost['sub'] = 1
self.cost['del'] = 1
self.cost['ins'] = 1
def calculate(self, lab, rec):
# Initialization
lab.insert(0, '')
rec.insert(0, '')
while len(self.space) < len(lab):
self.space.append([])
for row in self.space:
for element in row:
element['dist'] = 0
element['error'] = 'non'
while len(row) < len(rec):
row.append({'dist': 0, 'error': 'non'})
for i in range(len(lab)):
self.space[i][0]['dist'] = i
self.space[i][0]['error'] = 'del'
for j in range(len(rec)):
self.space[0][j]['dist'] = j
self.space[0][j]['error'] = 'ins'
self.space[0][0]['error'] = 'non'
for token in lab:
if token not in self.data and len(token) > 0:
self.data[token] = {
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
for token in rec:
if token not in self.data and len(token) > 0:
self.data[token] = {
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
# Computing edit distance
for i, lab_token in enumerate(lab):
for j, rec_token in enumerate(rec):
if i == 0 or j == 0:
continue
min_dist = sys.maxsize
min_error = 'none'
dist = self.space[i - 1][j]['dist'] + self.cost['del']
error = 'del'
if dist < min_dist:
min_dist = dist
min_error = error
dist = self.space[i][j - 1]['dist'] + self.cost['ins']
error = 'ins'
if dist < min_dist:
min_dist = dist
min_error = error
if lab_token == rec_token:
dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor']
error = 'cor'
else:
dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub']
error = 'sub'
if dist < min_dist:
min_dist = dist
min_error = error
self.space[i][j]['dist'] = min_dist
self.space[i][j]['error'] = min_error
# Tracing back
result = {
'lab': [],
'rec': [],
'all': 0,
'cor': 0,
'sub': 0,
'ins': 0,
'del': 0
}
i = len(lab) - 1
j = len(rec) - 1
while True:
if self.space[i][j]['error'] == 'cor': # correct
if len(lab[i]) > 0:
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
result['all'] = result['all'] + 1
result['cor'] = result['cor'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, rec[j])
i = i - 1
j = j - 1
elif self.space[i][j]['error'] == 'sub': # substitution
if len(lab[i]) > 0:
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
result['all'] = result['all'] + 1
result['sub'] = result['sub'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, rec[j])
i = i - 1
j = j - 1
elif self.space[i][j]['error'] == 'del': # deletion
if len(lab[i]) > 0:
self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
result['all'] = result['all'] + 1
result['del'] = result['del'] + 1
result['lab'].insert(0, lab[i])
result['rec'].insert(0, "")
i = i - 1
elif self.space[i][j]['error'] == 'ins': # insertion
if len(rec[j]) > 0:
self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
result['ins'] = result['ins'] + 1
result['lab'].insert(0, "")
result['rec'].insert(0, rec[j])
j = j - 1
elif self.space[i][j]['error'] == 'non': # starting point
break
else: # shouldn't reach here
print(
'this should not happen , i = {i} , j = {j} , error = {error}'.
format(i=i, j=j, error=self.space[i][j]['error']))
return result
def overall(self):
result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
for token in self.data:
result['all'] = result['all'] + self.data[token]['all']
result['cor'] = result['cor'] + self.data[token]['cor']
result['sub'] = result['sub'] + self.data[token]['sub']
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result
def cluster(self, data):
result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
for token in data:
if token in self.data:
result['all'] = result['all'] + self.data[token]['all']
result['cor'] = result['cor'] + self.data[token]['cor']
result['sub'] = result['sub'] + self.data[token]['sub']
result['ins'] = result['ins'] + self.data[token]['ins']
result['del'] = result['del'] + self.data[token]['del']
return result
def keys(self):
return list(self.data.keys())
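# Illustrative (not part of this diff): a minimal end-to-end use of Calculator.
# calculate() aligns a reference against a hypothesis and tallies cor/sub/del/ins,
# which is what the WER report further below is built from.
_calc = Calculator()
_res = _calc.calculate(list(u'今天天气'), list(u'今天气'))  # expect one deletion
assert (_res['all'], _res['cor'], _res['sub'], _res['del'], _res['ins']) == (4, 3, 0, 1, 0)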
def width(string):
return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
def default_cluster(word) :
unicode_names = [ unicodedata.name(char) for char in word ]
for i in reversed(range(len(unicode_names))) :
if unicode_names[i].startswith('DIGIT') : # 1
unicode_names[i] = 'Number' # 'DIGIT'
elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or
unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')) :
# 明 / 郎
unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH'
elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or
unicode_names[i].startswith('LATIN SMALL LETTER')) :
# A / a
unicode_names[i] = 'English' # 'LATIN LETTER'
elif unicode_names[i].startswith('HIRAGANA LETTER') : # は こ め
unicode_names[i] = 'Japanese' # 'GANA LETTER'
elif (unicode_names[i].startswith('AMPERSAND') or
unicode_names[i].startswith('APOSTROPHE') or
unicode_names[i].startswith('COMMERCIAL AT') or
unicode_names[i].startswith('DEGREE CELSIUS') or
unicode_names[i].startswith('EQUALS SIGN') or
unicode_names[i].startswith('FULL STOP') or
unicode_names[i].startswith('HYPHEN-MINUS') or
unicode_names[i].startswith('LOW LINE') or
unicode_names[i].startswith('NUMBER SIGN') or
unicode_names[i].startswith('PLUS SIGN') or
unicode_names[i].startswith('SEMICOLON')) :
# & / ' / @ / ℃ / = / . / - / _ / # / + / ;
del unicode_names[i]
else :
return 'Other'
if len(unicode_names) == 0 :
return 'Other'
if len(unicode_names) == 1 :
return unicode_names[0]
for i in range(len(unicode_names)-1) :
if unicode_names[i] != unicode_names[i+1] :
return 'Other'
return unicode_names[0]
def usage() :
print("compute-wer.py : compute word error rate (WER) and align recognition results and references.")
print(" usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer")
def default_cluster(word):
unicode_names = [unicodedata.name(char) for char in word]
for i in reversed(range(len(unicode_names))):
if unicode_names[i].startswith('DIGIT'): # 1
unicode_names[i] = 'Number' # 'DIGIT'
elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or
unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')):
# 明 / 郎
unicode_names[i] = 'Mandarin' # 'CJK IDEOGRAPH'
elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or
unicode_names[i].startswith('LATIN SMALL LETTER')):
# A / a
unicode_names[i] = 'English' # 'LATIN LETTER'
elif unicode_names[i].startswith('HIRAGANA LETTER'): # は こ め
unicode_names[i] = 'Japanese' # 'GANA LETTER'
elif (unicode_names[i].startswith('AMPERSAND') or
unicode_names[i].startswith('APOSTROPHE') or
unicode_names[i].startswith('COMMERCIAL AT') or
unicode_names[i].startswith('DEGREE CELSIUS') or
unicode_names[i].startswith('EQUALS SIGN') or
unicode_names[i].startswith('FULL STOP') or
unicode_names[i].startswith('HYPHEN-MINUS') or
unicode_names[i].startswith('LOW LINE') or
unicode_names[i].startswith('NUMBER SIGN') or
unicode_names[i].startswith('PLUS SIGN') or
unicode_names[i].startswith('SEMICOLON')):
# & / ' / @ / ℃ / = / . / - / _ / # / + / ;
del unicode_names[i]
else:
return 'Other'
if len(unicode_names) == 0:
return 'Other'
if len(unicode_names) == 1:
return unicode_names[0]
for i in range(len(unicode_names) - 1):
if unicode_names[i] != unicode_names[i + 1]:
return 'Other'
return unicode_names[0]
def usage():
print(
"compute-wer.py : compute word error rate (WER) and align recognition results and references."
)
print(
" usage : python compute-wer.py [--cs={0,1}] [--cluster=foo] [--ig=ignore_file] [--char={0,1}] [--v={0,1}] [--padding-symbol={space,underline}] test.ref test.hyp > test.wer"
)
if __name__ == '__main__':
if len(sys.argv) == 1 :
usage()
sys.exit(0)
calculator = Calculator()
cluster_file = ''
ignore_words = set()
tochar = False
verbose= 1
padding_symbol= ' '
case_sensitive = False
max_words_per_line = sys.maxsize
split = None
while len(sys.argv) > 3:
a = '--maxw='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):]
del sys.argv[1]
max_words_per_line = int(b)
continue
a = '--rt='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
remove_tag = (b == 'true') or (b != '0')
continue
a = '--cs='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
case_sensitive = (b == 'true') or (b != '0')
continue
a = '--cluster='
if sys.argv[1].startswith(a):
cluster_file = sys.argv[1][len(a):]
del sys.argv[1]
continue
a = '--splitfile='
if sys.argv[1].startswith(a):
split_file = sys.argv[1][len(a):]
del sys.argv[1]
split = dict()
with codecs.open(split_file, 'r', 'utf-8') as fh:
for line in fh: # line in unicode
words = line.strip().split()
if len(words) >= 2:
split[words[0]] = words[1:]
continue
a = '--ig='
if sys.argv[1].startswith(a):
ignore_file = sys.argv[1][len(a):]
del sys.argv[1]
with codecs.open(ignore_file, 'r', 'utf-8') as fh:
for line in fh: # line in unicode
line = line.strip()
if len(line) > 0:
ignore_words.add(line)
continue
a = '--char='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
tochar = (b == 'true') or (b != '0')
continue
a = '--v='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
verbose=0
try:
verbose=int(b)
except:
if b == 'true' or b != '0':
verbose = 1
continue
a = '--padding-symbol='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
if b == 'space':
padding_symbol= ' '
elif b == 'underline':
padding_symbol= '_'
continue
if True or sys.argv[1].startswith('-'):
#ignore invalid switch
del sys.argv[1]
continue
if len(sys.argv) == 1:
usage()
sys.exit(0)
calculator = Calculator()
cluster_file = ''
ignore_words = set()
tochar = False
verbose = 1
padding_symbol = ' '
case_sensitive = False
max_words_per_line = sys.maxsize
split = None
while len(sys.argv) > 3:
a = '--maxw='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):]
del sys.argv[1]
max_words_per_line = int(b)
continue
a = '--rt='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
remove_tag = (b == 'true') or (b != '0')
continue
a = '--cs='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
case_sensitive = (b == 'true') or (b != '0')
continue
a = '--cluster='
if sys.argv[1].startswith(a):
cluster_file = sys.argv[1][len(a):]
del sys.argv[1]
continue
a = '--splitfile='
if sys.argv[1].startswith(a):
split_file = sys.argv[1][len(a):]
del sys.argv[1]
split = dict()
with codecs.open(split_file, 'r', 'utf-8') as fh:
for line in fh: # line in unicode
words = line.strip().split()
if len(words) >= 2:
split[words[0]] = words[1:]
continue
a = '--ig='
if sys.argv[1].startswith(a):
ignore_file = sys.argv[1][len(a):]
del sys.argv[1]
with codecs.open(ignore_file, 'r', 'utf-8') as fh:
for line in fh: # line in unicode
line = line.strip()
if len(line) > 0:
ignore_words.add(line)
continue
a = '--char='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
tochar = (b == 'true') or (b != '0')
continue
a = '--v='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
verbose = 0
try:
verbose = int(b)
except:
if b == 'true' or b != '0':
verbose = 1
continue
a = '--padding-symbol='
if sys.argv[1].startswith(a):
b = sys.argv[1][len(a):].lower()
del sys.argv[1]
if b == 'space':
padding_symbol = ' '
elif b == 'underline':
padding_symbol = '_'
continue
if True or sys.argv[1].startswith('-'):
#ignore invalid switch
del sys.argv[1]
continue
if not case_sensitive:
ig=set([w.upper() for w in ignore_words])
ignore_words = ig
if not case_sensitive:
ig = set([w.upper() for w in ignore_words])
ignore_words = ig
default_clusters = {}
default_words = {}
default_clusters = {}
default_words = {}
ref_file = sys.argv[1]
hyp_file = sys.argv[2]
rec_set = {}
if split and not case_sensitive:
newsplit = dict()
for w in split:
words = split[w]
for i in range(len(words)):
words[i] = words[i].upper()
newsplit[w.upper()] = words
split = newsplit
ref_file = sys.argv[1]
hyp_file = sys.argv[2]
rec_set = {}
if split and not case_sensitive:
newsplit = dict()
for w in split:
words = split[w]
for i in range(len(words)):
words[i] = words[i].upper()
newsplit[w.upper()] = words
split = newsplit
with codecs.open(hyp_file, 'r', 'utf-8') as fh:
for line in fh:
with codecs.open(hyp_file, 'r', 'utf-8') as fh:
for line in fh:
if tochar:
array = characterize(line)
else:
array = line.strip().split()
if len(array) == 0: continue
fid = array[0]
rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive,
split)
# compute error rate on the interaction of reference file and hyp file
for line in open(ref_file, 'r', encoding='utf-8'):
if tochar:
array = characterize(line)
else:
array = line.strip().split()
if len(array)==0: continue
array = line.rstrip('\n').split()
if len(array) == 0: continue
fid = array[0]
rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive, split)
if fid not in rec_set:
continue
lab = normalize(array[1:], ignore_words, case_sensitive, split)
rec = rec_set[fid]
if verbose:
print('\nutt: %s' % fid)
# compute error rate on the interaction of reference file and hyp file
for line in open(ref_file, 'r', encoding='utf-8') :
if tochar:
array = characterize(line)
else:
array = line.rstrip('\n').split()
if len(array)==0: continue
fid = array[0]
if fid not in rec_set:
continue
lab = normalize(array[1:], ignore_words, case_sensitive, split)
rec = rec_set[fid]
if verbose:
print('\nutt: %s' % fid)
for word in rec + lab:
if word not in default_words:
default_cluster_name = default_cluster(word)
if default_cluster_name not in default_clusters:
default_clusters[default_cluster_name] = {}
if word not in default_clusters[default_cluster_name]:
default_clusters[default_cluster_name][word] = 1
default_words[word] = default_cluster_name
for word in rec + lab :
if word not in default_words :
default_cluster_name = default_cluster(word)
if default_cluster_name not in default_clusters :
default_clusters[default_cluster_name] = {}
if word not in default_clusters[default_cluster_name] :
default_clusters[default_cluster_name][word] = 1
default_words[word] = default_cluster_name
result = calculator.calculate(lab, rec)
if verbose:
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] + result[
'del']) * 100.0 / result['all']
else:
wer = 0.0
print('WER: %4.2f %%' % wer, end=' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'],
result['ins']))
space = {}
space['lab'] = []
space['rec'] = []
for idx in range(len(result['lab'])):
len_lab = width(result['lab'][idx])
len_rec = width(result['rec'][idx])
length = max(len_lab, len_rec)
space['lab'].append(length - len_lab)
space['rec'].append(length - len_rec)
upper_lab = len(result['lab'])
upper_rec = len(result['rec'])
lab1, rec1 = 0, 0
while lab1 < upper_lab or rec1 < upper_rec:
if verbose > 1:
print('lab(%s):' % fid.encode('utf-8'), end=' ')
else:
print('lab:', end=' ')
lab2 = min(upper_lab, lab1 + max_words_per_line)
for idx in range(lab1, lab2):
token = result['lab'][idx]
print('{token}'.format(token=token), end='')
for n in range(space['lab'][idx]):
print(padding_symbol, end='')
print(' ', end='')
print()
if verbose > 1:
print('rec(%s):' % fid.encode('utf-8'), end=' ')
else:
print('rec:', end=' ')
rec2 = min(upper_rec, rec1 + max_words_per_line)
for idx in range(rec1, rec2):
token = result['rec'][idx]
print('{token}'.format(token=token), end='')
for n in range(space['rec'][idx]):
print(padding_symbol, end='')
print(' ', end='')
print('\n', end='\n')
lab1 = lab2
rec1 = rec2
result = calculator.calculate(lab, rec)
if verbose:
if result['all'] != 0 :
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else :
wer = 0.0
print('WER: %4.2f %%' % wer, end = ' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
space = {}
space['lab'] = []
space['rec'] = []
for idx in range(len(result['lab'])) :
len_lab = width(result['lab'][idx])
len_rec = width(result['rec'][idx])
length = max(len_lab, len_rec)
space['lab'].append(length-len_lab)
space['rec'].append(length-len_rec)
upper_lab = len(result['lab'])
upper_rec = len(result['rec'])
lab1, rec1 = 0, 0
while lab1 < upper_lab or rec1 < upper_rec:
if verbose > 1:
print('lab(%s):' % fid.encode('utf-8'), end = ' ')
else:
print('lab:', end = ' ')
lab2 = min(upper_lab, lab1 + max_words_per_line)
for idx in range(lab1, lab2):
token = result['lab'][idx]
print('{token}'.format(token = token), end = '')
for n in range(space['lab'][idx]) :
print(padding_symbol, end = '')
print(' ',end='')
print()
if verbose > 1:
print('rec(%s):' % fid.encode('utf-8'), end = ' ')
else:
print('rec:', end = ' ')
rec2 = min(upper_rec, rec1 + max_words_per_line)
for idx in range(rec1, rec2):
token = result['rec'][idx]
print('{token}'.format(token = token), end = '')
for n in range(space['rec'][idx]) :
print(padding_symbol, end = '')
print(' ',end='')
print('\n', end='\n')
lab1 = lab2
rec1 = rec2
if verbose:
print('===========================================================================')
print()
result = calculator.overall()
if result['all'] != 0 :
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else :
wer = 0.0
print('Overall -> %4.2f %%' % wer, end = ' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
if not verbose:
print()
print(
'==========================================================================='
)
print()
if verbose:
for cluster_id in default_clusters :
result = calculator.cluster([ k for k in default_clusters[cluster_id] ])
if result['all'] != 0 :
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else :
result = calculator.overall()
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] + result[
'del']) * 100.0 / result['all']
else:
wer = 0.0
print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
if len(cluster_file) > 0 : # compute separated WERs for word clusters
cluster_id = ''
cluster = []
for line in open(cluster_file, 'r', encoding='utf-8') :
for token in line.decode('utf-8').rstrip('\n').split() :
# end of cluster reached, like </Keyword>
if token[0:2] == '</' and token[len(token)-1] == '>' and \
token.lstrip('</').rstrip('>') == cluster_id :
result = calculator.cluster(cluster)
if result['all'] != 0 :
wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
else :
wer = 0.0
print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'], result['ins']))
cluster_id = ''
cluster = []
# begin of cluster reached, like <Keyword>
elif token[0] == '<' and token[len(token)-1] == '>' and \
cluster_id == '' :
cluster_id = token.lstrip('<').rstrip('>')
cluster = []
# general terms, like WEATHER / CAR / ...
else :
cluster.append(token)
print()
print('===========================================================================')
\ No newline at end of file
print('Overall -> %4.2f %%' % wer, end=' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'],
result['ins']))
if not verbose:
print()
if verbose:
for cluster_id in default_clusters:
result = calculator.cluster(
[k for k in default_clusters[cluster_id]])
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] + result[
'del']) * 100.0 / result['all']
else:
wer = 0.0
print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'], result['del'],
result['ins']))
if len(cluster_file) > 0: # compute separated WERs for word clusters
cluster_id = ''
cluster = []
for line in open(cluster_file, 'r', encoding='utf-8'):
for token in line.decode('utf-8').rstrip('\n').split():
# end of cluster reached, like </Keyword>
if token[0:2] == '</' and token[len(token)-1] == '>' and \
token.lstrip('</').rstrip('>') == cluster_id :
result = calculator.cluster(cluster)
if result['all'] != 0:
wer = float(result['ins'] + result['sub'] + result[
'del']) * 100.0 / result['all']
else:
wer = 0.0
print('%s -> %4.2f %%' % (cluster_id, wer), end=' ')
print('N=%d C=%d S=%d D=%d I=%d' %
(result['all'], result['cor'], result['sub'],
result['del'], result['ins']))
cluster_id = ''
cluster = []
# begin of cluster reached, like <Keyword>
elif token[0] == '<' and token[len(token)-1] == '>' and \
cluster_id == '' :
cluster_id = token.lstrip('<').rstrip('>')
cluster = []
# general terms, like WEATHER / CAR / ...
else:
cluster.append(token)
print()
print(
'==========================================================================='
)
import os
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import jsonlines
def trans_hyp(origin_hyp,
trans_hyp = None,
trans_hyp_sclite = None):
def trans_hyp(origin_hyp, trans_hyp=None, trans_hyp_sclite=None):
"""
Args:
origin_hyp: The input json file which contains the model output
......@@ -17,19 +27,18 @@ def trans_hyp(origin_hyp,
with open(origin_hyp, "r+", encoding="utf8") as f:
for item in jsonlines.Reader(f):
input_dict[item["utt"]] = item["hyps"][0]
if trans_hyp is not None:
if trans_hyp is not None:
with open(trans_hyp, "w+", encoding="utf8") as f:
for key in input_dict.keys():
f.write(key + " " + input_dict[key] + "\n")
if trans_hyp_sclite is not None:
if trans_hyp_sclite is not None:
with open(trans_hyp_sclite, "w+") as f:
for key in input_dict.keys():
line = input_dict[key] + "(" + key + ".wav" +")" + "\n"
line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
f.write(line)
def trans_ref(origin_ref,
trans_ref = None,
trans_ref_sclite = None):
def trans_ref(origin_ref, trans_ref=None, trans_ref_sclite=None):
"""
Args:
origin_hyp: The input json file which contains the model output
......@@ -49,42 +58,48 @@ def trans_ref(origin_ref,
if trans_ref_sclite is not None:
with open(trans_ref_sclite, "w") as f:
for key in input_dict.keys():
line = input_dict[key] + "(" + key + ".wav" +")" + "\n"
line = input_dict[key] + "(" + key + ".wav" + ")" + "\n"
f.write(line)
if __name__ == "__main__":
parser = argparse.ArgumentParser(prog='format hyp file for compute CER/WER', add_help=True)
parser = argparse.ArgumentParser(
prog='format hyp file for compute CER/WER', add_help=True)
parser.add_argument(
'--origin_hyp',
type=str,
default = None,
help='origin hyp file')
'--origin_hyp', type=str, default=None, help='origin hyp file')
parser.add_argument(
        '--trans_hyp', type=str, default = None, help='hyp file for calculating CER/WER')
'--trans_hyp',
type=str,
default=None,
        help='hyp file for calculating CER/WER')
parser.add_argument(
        '--trans_hyp_sclite', type=str, default = None, help='hyp file for calculating CER/WER by sclite')
'--trans_hyp_sclite',
type=str,
default=None,
        help='hyp file for calculating CER/WER by sclite')
parser.add_argument(
'--origin_ref',
type=str,
default = None,
help='origin ref file')
'--origin_ref', type=str, default=None, help='origin ref file')
parser.add_argument(
        '--trans_ref', type=str, default = None, help='ref file for calculating CER/WER')
'--trans_ref',
type=str,
default=None,
        help='ref file for calculating CER/WER')
parser.add_argument(
        '--trans_ref_sclite', type=str, default = None, help='ref file for calculating CER/WER by sclite')
'--trans_ref_sclite',
type=str,
default=None,
        help='ref file for calculating CER/WER by sclite')
parser_args = parser.parse_args()
if parser_args.origin_hyp is not None:
trans_hyp(
origin_hyp = parser_args.origin_hyp,
trans_hyp = parser_args.trans_hyp,
trans_hyp_sclite = parser_args.trans_hyp_sclite, )
origin_hyp=parser_args.origin_hyp,
trans_hyp=parser_args.trans_hyp,
trans_hyp_sclite=parser_args.trans_hyp_sclite, )
if parser_args.origin_ref is not None:
trans_ref(
origin_ref = parser_args.origin_ref,
trans_ref = parser_args.trans_ref,
trans_ref_sclite = parser_args.trans_ref_sclite, )
origin_ref=parser_args.origin_ref,
trans_ref=parser_args.trans_ref,
trans_ref_sclite=parser_args.trans_ref_sclite, )
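# Illustrative (not part of this diff): the two output formats trans_hyp writes for a
# single jsonlines record, mirroring the f.write() calls above. The utt id is made up.
record = {"utt": "utt1", "hyps": [u"甚至出现交易几乎停滞的情况"]}
plain_line = record["utt"] + " " + record["hyps"][0] + "\n"                   # for CER/WER
sclite_line = record["hyps"][0] + "(" + record["utt"] + ".wav" + ")" + "\n"   # for sclite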
......@@ -35,7 +35,7 @@ def main(args):
# used to filter polyphone and invalid word
lexicon_table = set()
in_n = 0 # in lexicon word count
    out_n = 0 # out lexicon word count
    out_n = 0  # out lexicon word count
with open(args.in_lexicon, 'r') as fin, \
open(args.out_lexicon, 'w') as fout:
for line in fin:
......@@ -82,7 +82,10 @@ def main(args):
lexicon_table.add(word)
out_n += 1
print(f"Filter lexicon by unit table: filter out {in_n - out_n}, {out_n}/{in_n}")
print(
f"Filter lexicon by unit table: filter out {in_n - out_n}, {out_n}/{in_n}"
)
if __name__ == '__main__':
parser = argparse.ArgumentParser(
......