未验证 提交 319c8059 编写于 作者: MrEO's avatar MrEO 提交者: GitHub

[TTS] Support set device id for tts prediction, test=tts (#3019)

上级 817263fd
...@@ -490,6 +490,7 @@ def get_predictor( ...@@ -490,6 +490,7 @@ def get_predictor(
device: str='cpu', device: str='cpu',
# for gpu # for gpu
use_trt: bool=False, use_trt: bool=False,
device_id: int=0,
# for trt # for trt
use_dynamic_shape: bool=True, use_dynamic_shape: bool=True,
min_subgraph_size: int=5, min_subgraph_size: int=5,
...@@ -505,6 +506,7 @@ def get_predictor( ...@@ -505,6 +506,7 @@ def get_predictor(
params_file (os.PathLike): name of params_file. params_file (os.PathLike): name of params_file.
device (str): Choose the device you want to run, it can be: cpu/gpu, default is cpu. device (str): Choose the device you want to run, it can be: cpu/gpu, default is cpu.
use_trt (bool): whether to use TensorRT or not in GPU. use_trt (bool): whether to use TensorRT or not in GPU.
device_id (int): Choose your device id, only valid when the device is gpu, default 0.
use_dynamic_shape (bool): use dynamic shape or not in TensorRT. use_dynamic_shape (bool): use dynamic shape or not in TensorRT.
use_mkldnn (bool): whether to use MKLDNN or not in CPU. use_mkldnn (bool): whether to use MKLDNN or not in CPU.
cpu_threads (int): num of thread when use CPU. cpu_threads (int): num of thread when use CPU.
...@@ -521,7 +523,7 @@ def get_predictor( ...@@ -521,7 +523,7 @@ def get_predictor(
config.enable_memory_optim() config.enable_memory_optim()
config.switch_ir_optim(True) config.switch_ir_optim(True)
if device == "gpu": if device == "gpu":
config.enable_use_gpu(100, 0) config.enable_use_gpu(100, device_id)
else: else:
config.disable_gpu() config.disable_gpu()
config.set_cpu_math_library_num_threads(cpu_threads) config.set_cpu_math_library_num_threads(cpu_threads)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册