diff --git a/ppcls/data/dataloader/multilabel_dataset.py b/ppcls/data/dataloader/multilabel_dataset.py index 16bd5481f2318633e1d7f876ee7f8c4d359e932d..25dfc12b5730129dcb54bfd6eab95a440560b4aa 100644 --- a/ppcls/data/dataloader/multilabel_dataset.py +++ b/ppcls/data/dataloader/multilabel_dataset.py @@ -42,7 +42,7 @@ class MultiLabelDataset(CommonDataset): self.labels.append(labels) assert os.path.exists(self.images[-1]) if label_ratio: - return np.array(self.labels).mean(0) + return np.array(self.labels).mean(0).astype("float32") def __getitem__(self, idx): try: @@ -53,7 +53,7 @@ class MultiLabelDataset(CommonDataset): img = img.transpose((2, 0, 1)) label = np.array(self.labels[idx]).astype("float32") if self.label_ratio is not None: - return (img, [label, self.label_ratio]) + return (img, np.array([label, self.label_ratio])) else: return (img, label) diff --git a/ppcls/engine/evaluation/classification.py b/ppcls/engine/evaluation/classification.py index 38148240066e592830afdb682b1962a97db1278c..2477b1ffcb9d828eeab0c4dc4ad95a465b04edef 100644 --- a/ppcls/engine/evaluation/classification.py +++ b/ppcls/engine/evaluation/classification.py @@ -84,6 +84,7 @@ def classification_eval(engine, epoch_id=0): # gather Tensor when distributed if paddle.distributed.get_world_size() > 1: label_list = [] + paddle.distributed.all_gather(label_list, batch[1]) labels = paddle.concat(label_list, 0) diff --git a/ppcls/loss/multilabelloss.py b/ppcls/loss/multilabelloss.py index 52c31c7da5274fd8db8c61f59cc48e196cf31dd0..a88d8265a0c1fe9f21708ae27cabf6a5144f052d 100644 --- a/ppcls/loss/multilabelloss.py +++ b/ppcls/loss/multilabelloss.py @@ -38,7 +38,7 @@ class MultiLabelLoss(nn.Layer): def _binary_crossentropy(self, input, target, class_num): if self.weight_ratio: - target, label_ratio = target + target, label_ratio = target[:, 0, :], target[:, 1, :] if self.epsilon is not None: target = self._labelsmoothing(target, class_num) cost = F.binary_cross_entropy_with_logits( diff --git a/ppcls/metric/metrics.py b/ppcls/metric/metrics.py index 989499420c27948d8086e83bd33d246eb443af1c..7fe05be9ea8d390d1c1418fde4d9c13251e22392 100644 --- a/ppcls/metric/metrics.py +++ b/ppcls/metric/metrics.py @@ -363,6 +363,6 @@ class ATTRMetric(nn.Layer): self.threshold = threshold def forward(self, output, target): - metric_dict = get_attr_metrics(target[0].numpy(), + metric_dict = get_attr_metrics(target[:, 0, :].numpy(), output.numpy(), self.threshold) return metric_dict