提交 e60a63fb 编写于 作者: B BarryKCL

Rollback "get_input_ids"

上级 ab2a1219
...@@ -174,11 +174,11 @@ class Frontend(): ...@@ -174,11 +174,11 @@ class Frontend():
phones_list = [] phones_list = []
for seg in segments: for seg in segments:
phones = [] phones = []
initials = []
finals = []
# Replace all English words in the sentence # Replace all English words in the sentence
seg = re.sub('[a-zA-Z]+', '', seg) seg = re.sub('[a-zA-Z]+', '', seg)
seg_cut = psg.lcut(seg) seg_cut = psg.lcut(seg)
initials = []
finals = []
seg_cut = self.tone_modifier.pre_merge_for_modify(seg_cut) seg_cut = self.tone_modifier.pre_merge_for_modify(seg_cut)
if self.g2p_model == "g2pW": if self.g2p_model == "g2pW":
pinyins = self.g2pW_model(seg)[0] pinyins = self.g2pW_model(seg)[0]
...@@ -233,6 +233,7 @@ class Frontend(): ...@@ -233,6 +233,7 @@ class Frontend():
# assert len(sub_initials) == len(sub_finals) == len(word) # assert len(sub_initials) == len(sub_finals) == len(word)
initials = sum(initials, []) initials = sum(initials, [])
finals = sum(finals, []) finals = sum(finals, [])
for c, v in zip(initials, finals): for c, v in zip(initials, finals):
# NOTE: post process for pypinyin outputs # NOTE: post process for pypinyin outputs
# we discriminate i, ii and iii # we discriminate i, ii and iii
...@@ -365,15 +366,15 @@ class Frontend(): ...@@ -365,15 +366,15 @@ class Frontend():
print("----------------------------") print("----------------------------")
return phonemes return phonemes
def get_input_ids( def get_input_ids(self,
self, sentence: str,
sentence: str, merge_sentences: bool=True,
merge_sentences: bool=True, get_tone_ids: bool=False,
get_tone_ids: bool=False, robot: bool=False,
robot: bool=False, print_info: bool=False,
print_info: bool=False, add_blank: bool=False,
add_blank: bool=False, blank_token: str="<pad>",
blank_token: str="<pad>") -> Dict[str, List[paddle.Tensor]]: to_tensor: bool=True) -> Dict[str, List[paddle.Tensor]]:
phonemes = self.get_phonemes( phonemes = self.get_phonemes(
sentence, sentence,
merge_sentences=merge_sentences, merge_sentences=merge_sentences,
...@@ -384,20 +385,22 @@ class Frontend(): ...@@ -384,20 +385,22 @@ class Frontend():
tones = [] tones = []
temp_phone_ids = [] temp_phone_ids = []
temp_tone_ids = [] temp_tone_ids = []
for part_phonemes in phonemes: for part_phonemes in phonemes:
phones, tones = self._get_phone_tone( phones, tones = self._get_phone_tone(
part_phonemes, get_tone_ids=get_tone_ids) part_phonemes, get_tone_ids=get_tone_ids)
if add_blank: if add_blank:
phones = insert_after_character(phones, blank_token) phones = insert_after_character(phones, blank_token)
if tones: if tones:
tone_ids = self._t2id(tones) tone_ids = self._t2id(tones)
tone_ids = paddle.to_tensor(tone_ids) if to_tensor:
tone_ids = paddle.to_tensor(tone_ids)
temp_tone_ids.append(tone_ids) temp_tone_ids.append(tone_ids)
if phones: if phones:
phone_ids = self._p2id(phones) phone_ids = self._p2id(phones)
phone_ids = paddle.to_tensor(phone_ids) # if paddle.to_tensor() is used with onnxruntime, the first call will be too slow
if to_tensor:
phone_ids = paddle.to_tensor(phone_ids)
temp_phone_ids.append(phone_ids) temp_phone_ids.append(phone_ids)
if temp_tone_ids: if temp_tone_ids:
result["tone_ids"] = temp_tone_ids result["tone_ids"] = temp_tone_ids
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册