未验证 提交 15fd4414 编写于 作者: F Felix 提交者: GitHub

Update multilabel_dataset.py

上级 fa8dbc7f
...@@ -31,38 +31,8 @@ from ppcls.data.preprocess import transform ...@@ -31,38 +31,8 @@ from ppcls.data.preprocess import transform
from ppcls.utils import logger from ppcls.utils import logger
def create_operators(params):
"""
create operators based on the config
Args:
params(list): a dict list, used to create some operators
"""
assert isinstance(params, list), ('operator config should be a list')
ops = []
for operator in params:
print(operator)
assert isinstance(operator,
dict) and len(operator) == 1, "yaml format error"
op_name = list(operator)[0]
param = {} if operator[op_name] is None else operator[op_name]
op = getattr(preprocess, op_name)(**param)
ops.append(op)
return ops
class MultiLabelDataset(Dataset): class MultiLabelDataset(Dataset):
def __init__(
self,
image_root,
cls_label_path,
transform_ops=None, ):
self._img_root = image_root
self._cls_path = cls_label_path
if transform_ops:
self._transform_ops = create_operators(transform_ops)
self._dtype = paddle.get_default_dtype()
self._load_anno()
def _load_anno(self): def _load_anno(self):
assert os.path.exists(self._cls_path) assert os.path.exists(self._cls_path)
...@@ -99,12 +69,7 @@ class MultiLabelDataset(Dataset): ...@@ -99,12 +69,7 @@ class MultiLabelDataset(Dataset):
rnd_idx = np.random.randint(self.__len__()) rnd_idx = np.random.randint(self.__len__())
return self.__getitem__(rnd_idx) return self.__getitem__(rnd_idx)
def __len__(self):
return len(self.images)
@property
def class_num(self):
return len(set(self.labels))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册