Unverified commit 8b1c1ec4, authored by liangym, committed by GitHub

Merge branch 'PaddlePaddle:develop' into update_engine

......@@ -82,7 +82,7 @@ class CommonTaskResource:
self.model_tag = model_tag
self.version = version
self.res_dict = self.pretrained_models[model_tag][version]
self.format_path(self.res_dict)
self._format_path(self.res_dict)
self.res_dir = self._fetch(self.res_dict,
self._get_model_dir(model_type))
else:
......@@ -90,19 +90,10 @@ class CommonTaskResource:
self.voc_model_tag = model_tag
self.voc_version = version
self.voc_res_dict = self.pretrained_models[model_tag][version]
self.format_path(self.voc_res_dict)
self._format_path(self.voc_res_dict)
self.voc_res_dir = self._fetch(self.voc_res_dict,
self._get_model_dir(model_type))
@staticmethod
def format_path(res_dict: Dict[str, str]):
    """Convert '/'-separated resource paths in *res_dict* to OS-specific
    paths, in place.

    URL values (``http://`` / ``https://``) are left untouched.  Non-string
    values are skipped — previously ``'/' in v`` raised ``TypeError`` when a
    resource entry held e.g. a nested dict or an int.

    Args:
        res_dict (Dict[str, str]): Resource dict mapping keys to
            paths or URLs; modified in place.
    """
    for key, value in res_dict.items():
        # Guard: membership test below is only valid for strings.
        if not isinstance(value, str) or '/' not in value:
            continue
        # Remote URLs keep their '/' separators as-is.
        if value.startswith(('https://', 'http://')):
            continue
        res_dict[key] = os.path.join(*value.split('/'))
@staticmethod
def get_model_class(model_name) -> List[object]:
"""Dynamic import model class.
......@@ -231,3 +222,12 @@ class CommonTaskResource:
os.PathLike: Directory of model resource.
"""
return download_and_decompress(res_dict, target_dir)
@staticmethod
def _format_path(res_dict: Dict[str, str]):
for k, v in res_dict.items():
if isinstance(v, str) and '/' in v:
if v.startswith('https://') or v.startswith('http://'):
continue
else:
res_dict[k] = os.path.join(*(v.split('/')))
......@@ -90,7 +90,7 @@ def parse_args():
default=False,
help="whether use streaming acoustic model")
parser.add_argument(
"--chunk_size", type=int, default=42, help="chunk size of am streaming")
"--block_size", type=int, default=42, help="block size of am streaming")
parser.add_argument(
"--pad_size", type=int, default=12, help="pad size of am streaming")
......@@ -169,7 +169,7 @@ def main():
N = 0
T = 0
chunk_size = args.chunk_size
block_size = args.block_size
pad_size = args.pad_size
get_tone_ids = False
for utt_id, sentence in sentences:
......@@ -189,7 +189,7 @@ def main():
am_encoder_infer_predictor, input=phones)
if args.am_streaming:
hss = get_chunks(orig_hs, chunk_size, pad_size)
hss = get_chunks(orig_hs, block_size, pad_size)
chunk_num = len(hss)
mel_list = []
for i, hs in enumerate(hss):
......@@ -211,7 +211,7 @@ def main():
sub_mel = sub_mel[pad_size:]
else:
# 倒数几块的右侧也可能没有 pad 够
sub_mel = sub_mel[pad_size:(chunk_size + pad_size) -
sub_mel = sub_mel[pad_size:(block_size + pad_size) -
sub_mel.shape[0]]
mel_list.append(sub_mel)
mel = np.concatenate(mel_list, axis=0)
......
......@@ -97,7 +97,7 @@ def ort_predict(args):
T = 0
merge_sentences = True
get_tone_ids = False
chunk_size = args.chunk_size
block_size = args.block_size
pad_size = args.pad_size
for utt_id, sentence in sentences:
......@@ -115,7 +115,7 @@ def ort_predict(args):
orig_hs = am_encoder_infer_sess.run(
None, input_feed={'text': phone_ids})
if args.am_streaming:
hss = get_chunks(orig_hs[0], chunk_size, pad_size)
hss = get_chunks(orig_hs[0], block_size, pad_size)
chunk_num = len(hss)
mel_list = []
for i, hs in enumerate(hss):
......@@ -139,7 +139,7 @@ def ort_predict(args):
sub_mel = sub_mel[pad_size:]
else:
# 倒数几块的右侧也可能没有 pad 够
sub_mel = sub_mel[pad_size:(chunk_size + pad_size) -
sub_mel = sub_mel[pad_size:(block_size + pad_size) -
sub_mel.shape[0]]
mel_list.append(sub_mel)
mel = np.concatenate(mel_list, axis=0)
......@@ -236,7 +236,7 @@ def parse_args():
default=False,
help="whether use streaming acoustic model")
parser.add_argument(
"--chunk_size", type=int, default=42, help="chunk size of am streaming")
"--block_size", type=int, default=42, help="block size of am streaming")
parser.add_argument(
"--pad_size", type=int, default=12, help="pad size of am streaming")
......
......@@ -75,13 +75,13 @@ def denorm(data, mean, std):
return data * std + mean
def get_chunks(data, block_size: int, pad_size: int):
    """Split ``data`` into overlapping chunks along axis 1.

    The source span here is an interleaved old/new diff (duplicate ``def``
    lines, ``chunk_size`` vs ``block_size``); this is the resolved final
    version using ``block_size``.

    Each chunk covers ``block_size`` frames plus up to ``pad_size`` frames
    of context on each side, clipped at the sequence boundaries.

    Args:
        data: Array-like of shape (batch, time, feature) — assumed 3-D,
            sliced as ``data[:, start:end, :]``.
        block_size (int): Number of frames per chunk, excluding padding.
        pad_size (int): Context frames added on each side of a chunk.

    Returns:
        list: The chunks, each a slice of ``data`` along axis 1; empty
        when the time dimension is 0.
    """
    data_len = data.shape[1]
    chunks = []
    # Number of blocks needed to cover the whole time dimension.
    n = math.ceil(data_len / block_size)
    for i in range(n):
        start = max(0, i * block_size - pad_size)
        end = min((i + 1) * block_size + pad_size, data_len)
        chunks.append(data[:, start:end, :])
    return chunks
......
......@@ -133,7 +133,7 @@ def evaluate(args):
N = 0
T = 0
chunk_size = args.chunk_size
block_size = args.block_size
pad_size = args.pad_size
for utt_id, sentence in sentences:
......@@ -153,7 +153,7 @@ def evaluate(args):
# acoustic model
orig_hs = am_encoder_infer(phone_ids)
if args.am_streaming:
hss = get_chunks(orig_hs, chunk_size, pad_size)
hss = get_chunks(orig_hs, block_size, pad_size)
chunk_num = len(hss)
mel_list = []
for i, hs in enumerate(hss):
......@@ -171,7 +171,7 @@ def evaluate(args):
sub_mel = sub_mel[pad_size:]
else:
# 倒数几块的右侧也可能没有 pad 够
sub_mel = sub_mel[pad_size:(chunk_size + pad_size) -
sub_mel = sub_mel[pad_size:(block_size + pad_size) -
sub_mel.shape[0]]
mel_list.append(sub_mel)
mel = paddle.concat(mel_list, axis=0)
......@@ -277,7 +277,7 @@ def parse_args():
default=False,
help="whether use streaming acoustic model")
parser.add_argument(
"--chunk_size", type=int, default=42, help="chunk size of am streaming")
"--block_size", type=int, default=42, help="block size of am streaming")
parser.add_argument(
"--pad_size", type=int, default=12, help="pad size of am streaming")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register