diff --git a/ppcls/static/train.py b/ppcls/static/train.py index 86e832499345f581b1d1dc2c1ef40d6009491622..c58aaf528b264cfe23658996b695846f00f7edc1 100644 --- a/ppcls/static/train.py +++ b/ppcls/static/train.py @@ -93,8 +93,8 @@ def main(args): use_npu = global_config.get("use_npu", False) use_mlu = global_config.get("use_mlu", False) assert ( - use_gpu and use_xpu and use_npu and use_mlu - ) is not True, "gpu, xpu, npu and mlu can not be true in the same time in static mode!" + use_gpu + use_xpu + use_npu + use_mlu + use_ascend <= 1 + ), "gpu, xpu, npu, mlu and ascend can not be true in the same time in static mode!" if use_gpu: device = paddle.set_device('gpu')