未验证 提交 5b1c5eee 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

revert bug of pr #2115 (#2124)

* revert bug of pr #2115

* fix yaml
上级 d0ab4781
...@@ -22,6 +22,7 @@ Arch: ...@@ -22,6 +22,7 @@ Arch:
# if not null, its lengths should be same as models # if not null, its lengths should be same as models
pretrained_list: pretrained_list:
# if not null, its lengths should be same as models # if not null, its lengths should be same as models
infer_model_name: "Student"
freeze_params_list: freeze_params_list:
- True - True
- False - False
......
...@@ -42,7 +42,7 @@ class MultiLabelDataset(CommonDataset): ...@@ -42,7 +42,7 @@ class MultiLabelDataset(CommonDataset):
self.labels.append(labels) self.labels.append(labels)
assert os.path.exists(self.images[-1]) assert os.path.exists(self.images[-1])
if self.label_ratio: if self.label_ratio is not False:
return np.array(self.labels).mean(0).astype("float32") return np.array(self.labels).mean(0).astype("float32")
def __getitem__(self, idx): def __getitem__(self, idx):
...@@ -53,7 +53,7 @@ class MultiLabelDataset(CommonDataset): ...@@ -53,7 +53,7 @@ class MultiLabelDataset(CommonDataset):
img = transform(img, self._transform_ops) img = transform(img, self._transform_ops)
img = img.transpose((2, 0, 1)) img = img.transpose((2, 0, 1))
label = np.array(self.labels[idx]).astype("float32") label = np.array(self.labels[idx]).astype("float32")
if self.label_ratio: if self.label_ratio is not False:
return (img, np.array([label, self.label_ratio])) return (img, np.array([label, self.label_ratio]))
else: else:
return (img, label) return (img, label)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册