diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py old mode 100644 new mode 100755 index 95f264058d36cb7408c254d5975787451c24b691..2148fc50aeea0f86052b7382e6f841442c9cb0b9 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -100,7 +100,7 @@ class Engine(object): # set device assert self.config["Global"][ - "device"] in ["cpu", "gpu", "xpu", "npu", "mlu"] + "device"] in ["cpu", "gpu", "xpu", "npu", "mlu", "ascend"] self.device = paddle.set_device(self.config["Global"]["device"]) logger.info('train with paddle {} and device {}'.format( paddle.__version__, self.device)) diff --git a/ppcls/static/train.py b/ppcls/static/train.py old mode 100644 new mode 100755 index 86e832499345f581b1d1dc2c1ef40d6009491622..14eb661edf04b1d7e39cffbcc70655847fe03776 --- a/ppcls/static/train.py +++ b/ppcls/static/train.py @@ -92,9 +92,10 @@ def main(args): use_xpu = global_config.get("use_xpu", False) use_npu = global_config.get("use_npu", False) use_mlu = global_config.get("use_mlu", False) + use_ascend = global_config.get("use_ascend", False) assert ( - use_gpu and use_xpu and use_npu and use_mlu - ) is not True, "gpu, xpu, npu and mlu can not be true in the same time in static mode!" + use_gpu and use_xpu and use_npu and use_mlu and use_ascend + ) is not True, "gpu, xpu, npu, mlu and ascend can not be true in the same time in static mode!" if use_gpu: device = paddle.set_device('gpu') @@ -104,6 +105,8 @@ def main(args): device = paddle.set_device('npu') elif use_mlu: device = paddle.set_device('mlu') + elif use_ascend: + device = paddle.set_device('ascend') else: device = paddle.set_device('cpu')