提交 59b3de6a 编写于 作者: Z zhangkeliang

[NPU] test TransformerTTS with NPU

上级 3568bb62
...@@ -42,10 +42,12 @@ from paddlespeech.t2s.training.trainer import Trainer ...@@ -42,10 +42,12 @@ from paddlespeech.t2s.training.trainer import Trainer
def train_sp(args, config): def train_sp(args, config):
# decides device type and whether to run in parallel # decides device type and whether to run in parallel
# setup running environment correctly # setup running environment correctly
if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0: if paddle.is_compiled_with_cuda() and args.ngpu > 0:
paddle.set_device("cpu")
else:
paddle.set_device("gpu") paddle.set_device("gpu")
elif paddle.is_compiled_with_npu() and args.ngpu > 0:
paddle.set_device("npu")
else:
paddle.set_device("cpu")
world_size = paddle.distributed.get_world_size() world_size = paddle.distributed.get_world_size()
if world_size > 1: if world_size > 1:
paddle.distributed.init_parallel_env() paddle.distributed.init_parallel_env()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册