未验证 提交 319c8059 编写于 作者: MrEO's avatar MrEO 提交者: GitHub

[TTS] Support set device id for tts prediction, test=tts (#3019)

上级 817263fd
......@@ -490,6 +490,7 @@ def get_predictor(
device: str='cpu',
# for gpu
use_trt: bool=False,
device_id: int=0,
# for trt
use_dynamic_shape: bool=True,
min_subgraph_size: int=5,
......@@ -505,6 +506,7 @@ def get_predictor(
params_file (os.PathLike): name of params_file.
device (str): Choose the device you want to run, it can be: cpu/gpu, default is cpu.
use_trt (bool): whether to use TensorRT or not in GPU.
device_id (int): Choose your device id, only valid when the device is gpu, default 0.
use_dynamic_shape (bool): use dynamic shape or not in TensorRT.
use_mkldnn (bool): whether to use MKLDNN or not in CPU.
cpu_threads (int): num of thread when use CPU.
......@@ -521,7 +523,7 @@ def get_predictor(
config.enable_memory_optim()
config.switch_ir_optim(True)
if device == "gpu":
config.enable_use_gpu(100, 0)
config.enable_use_gpu(100, device_id)
else:
config.disable_gpu()
config.set_cpu_math_library_num_threads(cpu_threads)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册