提交 e55177c3 编写于 作者: Q QingshuChen

speedyspeech support kunlun

上级 b4387ab6
...@@ -174,12 +174,17 @@ def main(): ...@@ -174,12 +174,17 @@ def main():
parser.add_argument( parser.add_argument(
"--inference-dir", type=str, help="dir to save inference models") "--inference-dir", type=str, help="dir to save inference models")
parser.add_argument( parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") "--ngpu", type=int, default=1, help="if ngpu == 0, use cpu or xpu.")
parser.add_argument(
"--nxpu", type=int, default=0, help="if nxpu == 0 and ngpu == 0, use cpu.")
args, _ = parser.parse_known_args() args, _ = parser.parse_known_args()
if args.ngpu == 0: if args.ngpu == 0:
paddle.set_device("cpu") if args.nxpu == 0:
paddle.set_device("cpu")
else:
paddle.set_device("xpu")
elif args.ngpu > 0: elif args.ngpu > 0:
paddle.set_device("gpu") paddle.set_device("gpu")
else: else:
......
...@@ -46,7 +46,10 @@ def train_sp(args, config): ...@@ -46,7 +46,10 @@ def train_sp(args, config):
# setup running environment correctly # setup running environment correctly
world_size = paddle.distributed.get_world_size() world_size = paddle.distributed.get_world_size()
if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0: if (not paddle.is_compiled_with_cuda()) or args.ngpu == 0:
paddle.set_device("cpu") if (not paddle.is_compiled_with_xpu()) or args.nxpu == 0:
paddle.set_device("cpu")
else:
paddle.set_device("xpu")
else: else:
paddle.set_device("gpu") paddle.set_device("gpu")
if world_size > 1: if world_size > 1:
...@@ -185,7 +188,9 @@ def main(): ...@@ -185,7 +188,9 @@ def main():
parser.add_argument("--dev-metadata", type=str, help="dev data.") parser.add_argument("--dev-metadata", type=str, help="dev data.")
parser.add_argument("--output-dir", type=str, help="output dir.") parser.add_argument("--output-dir", type=str, help="output dir.")
parser.add_argument( parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu.") "--nxpu", type=int, default=0, help="if nxpu == 0 and ngpu == 0, use cpu.")
parser.add_argument(
"--ngpu", type=int, default=1, help="if ngpu == 0, use cpu or xpu")
parser.add_argument( parser.add_argument(
"--use-relative-path", "--use-relative-path",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册