diff --git a/ppcls/static/train.py b/ppcls/static/train.py index 64e0e35d1d71e251bf381bfc671424838193e361..898ebb025ae14d3480d54faa16c259f3fd08c6a7 100755 --- a/ppcls/static/train.py +++ b/ppcls/static/train.py @@ -90,8 +90,9 @@ def main(args): fleet.init(is_collective=True) # assign the device - assert global_config[ - "device"] in ["cpu", "gpu", "xpu", "npu", "mlu", "ascend"] + assert global_config["device"] in [ + "cpu", "gpu", "xpu", "npu", "mlu", "ascend", "intel_gpu", "mps" + ] device = paddle.set_device(global_config["device"]) # amp related config