diff --git a/ppcls/data/dataset/multilabel_dataset.py b/ppcls/data/dataset/multilabel_dataset.py index 913880c54b888ce80a227f5f352cad38b0be8500..c11555003a4d51a719d8d73892bbc4e835fbc9db 100644 --- a/ppcls/data/dataset/multilabel_dataset.py +++ b/ppcls/data/dataset/multilabel_dataset.py @@ -31,38 +31,8 @@ from ppcls.data.preprocess import transform from ppcls.utils import logger -def create_operators(params): - """ - create operators based on the config - Args: - params(list): a dict list, used to create some operators - """ - assert isinstance(params, list), ('operator config should be a list') - ops = [] - for operator in params: - print(operator) - assert isinstance(operator, - dict) and len(operator) == 1, "yaml format error" - op_name = list(operator)[0] - param = {} if operator[op_name] is None else operator[op_name] - op = getattr(preprocess, op_name)(**param) - ops.append(op) - - return ops - class MultiLabelDataset(Dataset): - def __init__( - self, - image_root, - cls_label_path, - transform_ops=None, ): - self._img_root = image_root - self._cls_path = cls_label_path - if transform_ops: - self._transform_ops = create_operators(transform_ops) - self._dtype = paddle.get_default_dtype() - self._load_anno() def _load_anno(self): assert os.path.exists(self._cls_path) @@ -99,12 +69,7 @@ class MultiLabelDataset(Dataset): rnd_idx = np.random.randint(self.__len__()) return self.__getitem__(rnd_idx) - def __len__(self): - return len(self.images) - @property - def class_num(self): - return len(set(self.labels))