提交 a472af86 编写于 作者: H huangqipeng 提交者: 努力努力在努力丶

feat: support mlu device and amp of mlu

上级 e2e492f9
...@@ -92,7 +92,7 @@ class Engine(object): ...@@ -92,7 +92,7 @@ class Engine(object):
self.vdl_writer = LogWriter(logdir=vdl_writer_path) self.vdl_writer = LogWriter(logdir=vdl_writer_path)
# set device # set device
assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu", "npu"] assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu", "npu", "mlu"]
self.device = paddle.set_device(self.config["Global"]["device"]) self.device = paddle.set_device(self.config["Global"]["device"])
logger.info('train with paddle {} and device {}'.format( logger.info('train with paddle {} and device {}'.format(
paddle.__version__, self.device)) paddle.__version__, self.device))
...@@ -108,9 +108,12 @@ class Engine(object): ...@@ -108,9 +108,12 @@ class Engine(object):
self.use_dynamic_loss_scaling = False self.use_dynamic_loss_scaling = False
if self.amp: if self.amp:
AMP_RELATED_FLAGS_SETTING = { AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
'FLAGS_max_inplace_grad_add': 8, 'FLAGS_max_inplace_grad_add': 8,
} }
if paddle.is_compiled_with_cuda():
AMP_RELATED_FLAGS_SETTING.update({
'FLAGS_cudnn_batchnorm_spatial_persistent': 1
})
paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING) paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
if "class_num" in config["Global"]: if "class_num" in config["Global"]:
......
...@@ -91,9 +91,10 @@ def main(args): ...@@ -91,9 +91,10 @@ def main(args):
use_xpu = global_config.get("use_xpu", False) use_xpu = global_config.get("use_xpu", False)
use_npu = global_config.get("use_npu", False) use_npu = global_config.get("use_npu", False)
use_mlu = global_config.get("use_mlu", False)
assert ( assert (
use_gpu and use_xpu and use_npu use_gpu and use_xpu and use_npu and use_mlu
) is not True, "gpu, xpu and npu can not be true in the same time in static mode!" ) is not True, "gpu, xpu, npu and mlu can not be true in the same time in static mode!"
if use_gpu: if use_gpu:
device = paddle.set_device('gpu') device = paddle.set_device('gpu')
...@@ -101,6 +102,8 @@ def main(args): ...@@ -101,6 +102,8 @@ def main(args):
device = paddle.set_device('xpu') device = paddle.set_device('xpu')
elif use_npu: elif use_npu:
device = paddle.set_device('npu') device = paddle.set_device('npu')
elif use_mlu:
device = paddle.set_device('mlu')
else: else:
device = paddle.set_device('cpu') device = paddle.set_device('cpu')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册