diff --git a/ppcls/data/__init__.py b/ppcls/data/__init__.py index bc61139f45f2f1a13bcd3178edc5f48b36184e87..4cba553262f782b58164fbd0d5e6933a6584204f 100644 --- a/ppcls/data/__init__.py +++ b/ppcls/data/__init__.py @@ -189,3 +189,8 @@ def build_dataloader(config, mode, device, use_dali=False, seed=None): logger.debug("build data_loader({}) success...".format(data_loader)) return data_loader + + +# for PaddleX +ClsDataset = ImageNetDataset +ShiTuDataset = ImageNetDataset