diff --git a/paddlespeech/t2s/exps/inference.py b/paddlespeech/t2s/exps/inference.py index 7a19a1133dee9b395fc706d41a04c9fe3507f99b..98e73e10269824816bbfcc56a8cc5a9a79ac28e1 100644 --- a/paddlespeech/t2s/exps/inference.py +++ b/paddlespeech/t2s/exps/inference.py @@ -14,6 +14,7 @@ import argparse from pathlib import Path +import paddle import soundfile as sf from timer import timer @@ -101,6 +102,9 @@ def parse_args(): # only inference for models trained with csmsc now def main(): args = parse_args() + + paddle.set_device(args.device) + # frontend frontend = get_frontend( lang=args.lang, diff --git a/paddlespeech/t2s/exps/inference_streaming.py b/paddlespeech/t2s/exps/inference_streaming.py index ef6d1a4ae43764c112336cb6496e7f8875fc4eba..b680f19a99b144c60dc5bca5992f2983509446ff 100644 --- a/paddlespeech/t2s/exps/inference_streaming.py +++ b/paddlespeech/t2s/exps/inference_streaming.py @@ -15,6 +15,7 @@ import argparse from pathlib import Path import numpy as np +import paddle import soundfile as sf from timer import timer @@ -100,6 +101,9 @@ def parse_args(): # only inference for models trained with csmsc now def main(): args = parse_args() + + paddle.set_device(args.device) + # frontend frontend = get_frontend( lang=args.lang, diff --git a/paddlespeech/t2s/exps/ort_predict.py b/paddlespeech/t2s/exps/ort_predict.py index adbd6809c3df7664e13169743f94c4a6c7f3d259..2e8596deda78de79d6c8d5e09448428c21864ea9 100644 --- a/paddlespeech/t2s/exps/ort_predict.py +++ b/paddlespeech/t2s/exps/ort_predict.py @@ -16,6 +16,7 @@ from pathlib import Path import jsonlines import numpy as np +import paddle import soundfile as sf from timer import timer @@ -25,6 +26,7 @@ from paddlespeech.t2s.utils import str2bool def ort_predict(args): + # construct dataset for evaluation with jsonlines.open(args.test_metadata, 'r') as reader: test_metadata = list(reader) @@ -143,6 +145,8 @@ def parse_args(): def main(): args = parse_args() + paddle.set_device(args.device) + ort_predict(args) diff --git a/paddlespeech/t2s/exps/ort_predict_e2e.py b/paddlespeech/t2s/exps/ort_predict_e2e.py index ae5e900b236f7273603a3f72ec70c9b4b1de203e..a2ef8e4c6da5c9eabf77b50c2b0153077d9426a5 100644 --- a/paddlespeech/t2s/exps/ort_predict_e2e.py +++ b/paddlespeech/t2s/exps/ort_predict_e2e.py @@ -15,6 +15,7 @@ import argparse from pathlib import Path import numpy as np +import paddle import soundfile as sf from timer import timer @@ -178,6 +179,8 @@ def parse_args(): def main(): args = parse_args() + paddle.set_device(args.device) + ort_predict(args) diff --git a/paddlespeech/t2s/exps/ort_predict_streaming.py b/paddlespeech/t2s/exps/ort_predict_streaming.py index 5568ed39019e71ead6ef24772d552bb1dc73b5c0..5d2c66bc934a3df5a9b23f5b182b12a42a71c64c 100644 --- a/paddlespeech/t2s/exps/ort_predict_streaming.py +++ b/paddlespeech/t2s/exps/ort_predict_streaming.py @@ -15,6 +15,7 @@ import argparse from pathlib import Path import numpy as np +import paddle import soundfile as sf from timer import timer @@ -246,6 +247,8 @@ def parse_args(): def main(): args = parse_args() + paddle.set_device(args.device) + ort_predict(args)