提交 78582b99 编写于 作者: G gaotingquan 提交者: Wei Shengyu

fix: replace use_gpu, etc. by device

上级 8fb6cc53
...@@ -76,8 +76,12 @@ def main(args): ...@@ -76,8 +76,12 @@ def main(args):
if global_config.get("is_distributed", True): if global_config.get("is_distributed", True):
fleet.init(is_collective=True) fleet.init(is_collective=True)
# assign the device # assign the device
use_gpu = global_config.get("use_gpu", True) assert global_config[
"device"] in ["cpu", "gpu", "xpu", "npu", "mlu", "ascend"]
device = paddle.set_device(global_config["device"])
# amp related config # amp related config
if 'AMP' in config: if 'AMP' in config:
AMP_RELATED_FLAGS_SETTING = { AMP_RELATED_FLAGS_SETTING = {
...@@ -89,24 +93,6 @@ def main(args): ...@@ -89,24 +93,6 @@ def main(args):
os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '1' os.environ['FLAGS_cudnn_batchnorm_spatial_persistent'] = '1'
paddle.set_flags(AMP_RELATED_FLAGS_SETTING) paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
use_xpu = global_config.get("use_xpu", False)
use_npu = global_config.get("use_npu", False)
use_mlu = global_config.get("use_mlu", False)
assert (
use_gpu + use_xpu + use_npu + use_mlu + use_ascend <= 1
), "gpu, xpu, npu, mlu and ascend can not be true in the same time in static mode!"
if use_gpu:
device = paddle.set_device('gpu')
elif use_xpu:
device = paddle.set_device('xpu')
elif use_npu:
device = paddle.set_device('npu')
elif use_mlu:
device = paddle.set_device('mlu')
else:
device = paddle.set_device('cpu')
# visualDL # visualDL
vdl_writer = None vdl_writer = None
if global_config["use_visualdl"]: if global_config["use_visualdl"]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册