提交 05ecf1d0 编写于 作者: Z zhiboniu

multi-card eval support

上级 824746b3
......@@ -42,7 +42,7 @@ class MultiLabelDataset(CommonDataset):
self.labels.append(labels)
assert os.path.exists(self.images[-1])
if label_ratio:
return np.array(self.labels).mean(0)
return np.array(self.labels).mean(0).astype("float32")
def __getitem__(self, idx):
try:
......@@ -53,7 +53,7 @@ class MultiLabelDataset(CommonDataset):
img = img.transpose((2, 0, 1))
label = np.array(self.labels[idx]).astype("float32")
if self.label_ratio is not None:
return (img, [label, self.label_ratio])
return (img, np.array([label, self.label_ratio]))
else:
return (img, label)
......
......@@ -84,6 +84,7 @@ def classification_eval(engine, epoch_id=0):
# gather Tensor when distributed
if paddle.distributed.get_world_size() > 1:
label_list = []
paddle.distributed.all_gather(label_list, batch[1])
labels = paddle.concat(label_list, 0)
......
......@@ -38,7 +38,7 @@ class MultiLabelLoss(nn.Layer):
def _binary_crossentropy(self, input, target, class_num):
if self.weight_ratio:
target, label_ratio = target
target, label_ratio = target[:, 0, :], target[:, 1, :]
if self.epsilon is not None:
target = self._labelsmoothing(target, class_num)
cost = F.binary_cross_entropy_with_logits(
......
......@@ -363,6 +363,6 @@ class ATTRMetric(nn.Layer):
self.threshold = threshold
def forward(self, output, target):
metric_dict = get_attr_metrics(target[0].numpy(),
metric_dict = get_attr_metrics(target[:, 0, :].numpy(),
output.numpy(), self.threshold)
return metric_dict
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册