未验证 提交 a0eb34a6 编写于 作者: R ronnywang 提交者: GitHub

Add npu supporting (#1324)

上级 cc00a51a
......@@ -91,7 +91,7 @@ class Engine(object):
self.vdl_writer = LogWriter(logdir=vdl_writer_path)
# set device
assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu"]
assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu", "npu"]
self.device = paddle.set_device(self.config["Global"]["device"])
logger.info('train with paddle {} and device {}'.format(
paddle.__version__, self.device))
......
......@@ -91,14 +91,17 @@ def main(args):
os.environ[k] = AMP_RELATED_FLAGS_SETTING[k]
use_xpu = global_config.get("use_xpu", False)
use_npu = global_config.get("use_npu", False)
assert (
use_gpu and use_xpu
) is not True, "gpu and xpu can not be true in the same time in static mode!"
use_gpu and use_xpu and use_npu
) is not True, "gpu, xpu and npu can not be true in the same time in static mode!"
if use_gpu:
device = paddle.set_device('gpu')
elif use_xpu:
device = paddle.set_device('xpu')
elif use_npu:
device = paddle.set_device('npu')
else:
device = paddle.set_device('cpu')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册