You need to sign in or sign up before continuing.
提交 4646f7cc 编写于 作者: 小湉湉's avatar 小湉湉

add paddle device set for ort and inference, test=doc

上级 c74fa9ad
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import argparse import argparse
from pathlib import Path from pathlib import Path
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
...@@ -101,6 +102,9 @@ def parse_args(): ...@@ -101,6 +102,9 @@ def parse_args():
# only inference for models trained with csmsc now # only inference for models trained with csmsc now
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
# frontend # frontend
frontend = get_frontend( frontend = get_frontend(
lang=args.lang, lang=args.lang,
......
...@@ -15,6 +15,7 @@ import argparse ...@@ -15,6 +15,7 @@ import argparse
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
...@@ -100,6 +101,9 @@ def parse_args(): ...@@ -100,6 +101,9 @@ def parse_args():
# only inference for models trained with csmsc now # only inference for models trained with csmsc now
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
# frontend # frontend
frontend = get_frontend( frontend = get_frontend(
lang=args.lang, lang=args.lang,
......
...@@ -16,6 +16,7 @@ from pathlib import Path ...@@ -16,6 +16,7 @@ from pathlib import Path
import jsonlines import jsonlines
import numpy as np import numpy as np
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
...@@ -25,6 +26,7 @@ from paddlespeech.t2s.utils import str2bool ...@@ -25,6 +26,7 @@ from paddlespeech.t2s.utils import str2bool
def ort_predict(args): def ort_predict(args):
# construct dataset for evaluation # construct dataset for evaluation
with jsonlines.open(args.test_metadata, 'r') as reader: with jsonlines.open(args.test_metadata, 'r') as reader:
test_metadata = list(reader) test_metadata = list(reader)
...@@ -143,6 +145,8 @@ def parse_args(): ...@@ -143,6 +145,8 @@ def parse_args():
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
ort_predict(args) ort_predict(args)
......
...@@ -15,6 +15,7 @@ import argparse ...@@ -15,6 +15,7 @@ import argparse
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
...@@ -178,6 +179,8 @@ def parse_args(): ...@@ -178,6 +179,8 @@ def parse_args():
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
ort_predict(args) ort_predict(args)
......
...@@ -15,6 +15,7 @@ import argparse ...@@ -15,6 +15,7 @@ import argparse
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
import paddle
import soundfile as sf import soundfile as sf
from timer import timer from timer import timer
...@@ -246,6 +247,8 @@ def parse_args(): ...@@ -246,6 +247,8 @@ def parse_args():
def main(): def main():
args = parse_args() args = parse_args()
paddle.set_device(args.device)
ort_predict(args) ort_predict(args)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册