未验证 提交 45f51651 编写于 作者: L liangym 提交者: GitHub

Merge pull request #2129 from lym0302/onnx_gpu

[server]specify id
...@@ -30,7 +30,9 @@ def get_sess(model_path: Optional[os.PathLike]=None, sess_conf: dict=None): ...@@ -30,7 +30,9 @@ def get_sess(model_path: Optional[os.PathLike]=None, sess_conf: dict=None):
# "gpu:0" # "gpu:0"
providers = ['CPUExecutionProvider'] providers = ['CPUExecutionProvider']
if "gpu" in sess_conf.get("device", ""): if "gpu" in sess_conf.get("device", ""):
providers = ['CUDAExecutionProvider'] device_id = int(sess_conf["device"].split(":")[1])
providers = [('CUDAExecutionProvider', {'device_id': device_id})]
# fastspeech2/mb_melgan can't use trt now! # fastspeech2/mb_melgan can't use trt now!
if sess_conf.get("use_trt", 0): if sess_conf.get("use_trt", 0):
providers = ['TensorrtExecutionProvider'] providers = ['TensorrtExecutionProvider']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册