From a472af86ec1fffb46e7ef716cbe293acbb74d826 Mon Sep 17 00:00:00 2001 From: huangqipeng Date: Mon, 14 Mar 2022 15:48:26 +0800 Subject: [PATCH] feat: support mlu device and amp of mlu --- ppcls/engine/engine.py | 7 +++++-- ppcls/static/train.py | 7 +++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index 61d09ff8..8d077b9d 100644 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -92,7 +92,7 @@ class Engine(object): self.vdl_writer = LogWriter(logdir=vdl_writer_path) # set device - assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu", "npu"] + assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu", "npu", "mlu"] self.device = paddle.set_device(self.config["Global"]["device"]) logger.info('train with paddle {} and device {}'.format( paddle.__version__, self.device)) @@ -108,9 +108,12 @@ class Engine(object): self.use_dynamic_loss_scaling = False if self.amp: AMP_RELATED_FLAGS_SETTING = { - 'FLAGS_cudnn_batchnorm_spatial_persistent': 1, 'FLAGS_max_inplace_grad_add': 8, } + if paddle.is_compiled_with_cuda(): + AMP_RELATED_FLAGS_SETTING.update({ + 'FLAGS_cudnn_batchnorm_spatial_persistent': 1 + }) paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING) if "class_num" in config["Global"]: diff --git a/ppcls/static/train.py b/ppcls/static/train.py index 9c03598b..1961dfaf 100644 --- a/ppcls/static/train.py +++ b/ppcls/static/train.py @@ -91,9 +91,10 @@ def main(args): use_xpu = global_config.get("use_xpu", False) use_npu = global_config.get("use_npu", False) + use_mlu = global_config.get("use_mlu", False) assert ( - use_gpu and use_xpu and use_npu - ) is not True, "gpu, xpu and npu can not be true in the same time in static mode!" + use_gpu and use_xpu and use_npu and use_mlu + ) is not True, "gpu, xpu, npu and mlu can not be true in the same time in static mode!" if use_gpu: device = paddle.set_device('gpu') @@ -101,6 +102,8 @@ def main(args): device = paddle.set_device('xpu') elif use_npu: device = paddle.set_device('npu') + elif use_mlu: + device = paddle.set_device('mlu') else: device = paddle.set_device('cpu') -- GitLab