diff --git a/ppcls/engine/engine.py b/ppcls/engine/engine.py index da829282ae348546b85db8a3ef3b4bd97003a3f1..3a8ebdef1ed099ade59818c240fe43200c07cc38 100755 --- a/ppcls/engine/engine.py +++ b/ppcls/engine/engine.py @@ -101,7 +101,7 @@ class Engine(object): # set device assert self.config["Global"][ - "device"] in ["cpu", "gpu", "xpu", "npu", "mlu", "ascend"] + "device"] in ["cpu", "gpu", "xpu", "npu", "mlu", "ascend", "intel_gpu", "mps"] self.device = paddle.set_device(self.config["Global"]["device"]) logger.info('train with paddle {} and device {}'.format( paddle.__version__, self.device))