未验证 提交 e0892686 编写于 作者: 小湉湉's avatar 小湉湉 提交者: GitHub

Merge pull request #1727 from yt605155624/refactor_syn_util

[TTS]add paddle device set for ort and inference
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import argparse import argparse
from pathlib import Path from pathlib import Path
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
...@@ -101,6 +102,9 @@ def parse_args(): ...@@ -101,6 +102,9 @@ def parse_args():
# only inference for models trained with csmsc now # only inference for models trained with csmsc now
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
# frontend # frontend
frontend = get_frontend( frontend = get_frontend(
lang=args.lang, lang=args.lang,
......
...@@ -15,6 +15,7 @@ import argparse ...@@ -15,6 +15,7 @@ import argparse
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
...@@ -100,6 +101,9 @@ def parse_args(): ...@@ -100,6 +101,9 @@ def parse_args():
# only inference for models trained with csmsc now # only inference for models trained with csmsc now
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
# frontend # frontend
frontend = get_frontend( frontend = get_frontend(
lang=args.lang, lang=args.lang,
......
...@@ -16,6 +16,7 @@ from pathlib import Path ...@@ -16,6 +16,7 @@ from pathlib import Path
import jsonlines import jsonlines
import numpy as np import numpy as np
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
...@@ -25,6 +26,7 @@ from paddlespeech.t2s.utils import str2bool ...@@ -25,6 +26,7 @@ from paddlespeech.t2s.utils import str2bool
def ort_predict(args): def ort_predict(args):
# construct dataset for evaluation # construct dataset for evaluation
with jsonlines.open(args.test_metadata, 'r') as reader: with jsonlines.open(args.test_metadata, 'r') as reader:
test_metadata = list(reader) test_metadata = list(reader)
...@@ -143,6 +145,8 @@ def parse_args(): ...@@ -143,6 +145,8 @@ def parse_args():
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
ort_predict(args) ort_predict(args)
......
...@@ -15,6 +15,7 @@ import argparse ...@@ -15,6 +15,7 @@ import argparse
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
...@@ -178,6 +179,8 @@ def parse_args(): ...@@ -178,6 +179,8 @@ def parse_args():
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
ort_predict(args) ort_predict(args)
......
...@@ -15,6 +15,7 @@ import argparse ...@@ -15,6 +15,7 @@ import argparse
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
...@@ -246,6 +247,8 @@ def parse_args(): ...@@ -246,6 +247,8 @@ def parse_args():
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
ort_predict(args) ort_predict(args)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册