diff --git a/ppcls/configs/PULC/vehicle_attribute/PPLCNet_x1_0_distillation.yaml b/ppcls/configs/PULC/vehicle_attribute/PPLCNet_x1_0_distillation.yaml index 208eb00d0d38bf4a182cac610ca67fc4ccef4625..d098ca81f303fa19a0ebf71145143c2c982dba39 100644 --- a/ppcls/configs/PULC/vehicle_attribute/PPLCNet_x1_0_distillation.yaml +++ b/ppcls/configs/PULC/vehicle_attribute/PPLCNet_x1_0_distillation.yaml @@ -22,6 +22,7 @@ Arch: # if not null, its lengths should be same as models pretrained_list: # if not null, its lengths should be same as models + infer_model_name: "Student" freeze_params_list: - True - False diff --git a/ppcls/data/dataloader/multilabel_dataset.py b/ppcls/data/dataloader/multilabel_dataset.py index f7c7627c55b5b3281ca3f1ef4560e26803389a6d..c67a5ae78f2592bc9be91f5c087ffd9023cddd1b 100644 --- a/ppcls/data/dataloader/multilabel_dataset.py +++ b/ppcls/data/dataloader/multilabel_dataset.py @@ -42,7 +42,7 @@ class MultiLabelDataset(CommonDataset): self.labels.append(labels) assert os.path.exists(self.images[-1]) - if self.label_ratio: + if self.label_ratio is not False: return np.array(self.labels).mean(0).astype("float32") def __getitem__(self, idx): @@ -53,7 +53,7 @@ class MultiLabelDataset(CommonDataset): img = transform(img, self._transform_ops) img = img.transpose((2, 0, 1)) label = np.array(self.labels[idx]).astype("float32") - if self.label_ratio: + if self.label_ratio is not False: return (img, np.array([label, self.label_ratio])) else: return (img, label)