From 5b1c5eee1ea30ffba80ba4e87284ba4b71eabecf Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Mon, 4 Jul 2022 13:58:22 +0800 Subject: [PATCH] revert bug of pr #2115 (#2124) * revert bug of pr #2115 * fix yaml --- .../PULC/vehicle_attribute/PPLCNet_x1_0_distillation.yaml | 1 + ppcls/data/dataloader/multilabel_dataset.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/ppcls/configs/PULC/vehicle_attribute/PPLCNet_x1_0_distillation.yaml b/ppcls/configs/PULC/vehicle_attribute/PPLCNet_x1_0_distillation.yaml index 208eb00d..d098ca81 100644 --- a/ppcls/configs/PULC/vehicle_attribute/PPLCNet_x1_0_distillation.yaml +++ b/ppcls/configs/PULC/vehicle_attribute/PPLCNet_x1_0_distillation.yaml @@ -22,6 +22,7 @@ Arch: # if not null, its lengths should be same as models pretrained_list: # if not null, its lengths should be same as models + infer_model_name: "Student" freeze_params_list: - True - False diff --git a/ppcls/data/dataloader/multilabel_dataset.py b/ppcls/data/dataloader/multilabel_dataset.py index f7c7627c..c67a5ae7 100644 --- a/ppcls/data/dataloader/multilabel_dataset.py +++ b/ppcls/data/dataloader/multilabel_dataset.py @@ -42,7 +42,7 @@ class MultiLabelDataset(CommonDataset): self.labels.append(labels) assert os.path.exists(self.images[-1]) - if self.label_ratio: + if self.label_ratio is not False: return np.array(self.labels).mean(0).astype("float32") def __getitem__(self, idx): @@ -53,7 +53,7 @@ class MultiLabelDataset(CommonDataset): img = transform(img, self._transform_ops) img = img.transpose((2, 0, 1)) label = np.array(self.labels[idx]).astype("float32") - if self.label_ratio: + if self.label_ratio is not False: return (img, np.array([label, self.label_ratio])) else: return (img, label) -- GitLab