未验证 提交 24fd19f0 编写于 作者: M Manan Goel 提交者: GitHub

feat(logger): W&B logger with VOC datasets (#1525)

feat(logger): W&B logger with VOC datasets
上级 74b637b4
......@@ -34,6 +34,7 @@ jobs:
pip install -r requirements.txt
pip install isort==4.3.21
pip install flake8==3.8.3
pip install "importlib-metadata<5.0"
# Runs a set of commands using the runners shell
- name: Format check
run: ./.github/workflows/format_check.sh
......@@ -119,6 +119,10 @@ class VOCDetection(Dataset):
self._annopath = os.path.join("%s", "Annotations", "%s.xml")
self._imgpath = os.path.join("%s", "JPEGImages", "%s.jpg")
self._classes = VOC_CLASSES
self.cats = [
{"id": idx, "name": val} for idx, val in enumerate(VOC_CLASSES)
]
self.class_ids = list(range(len(VOC_CLASSES)))
self.ids = list()
for (year, name) in image_sets:
self._year = year
......
......@@ -169,6 +169,8 @@ class WandbLogger(object):
"Please install wandb using pip install wandb"
)
from yolox.data.datasets import VOCDetection
self.project = project
self.name = name
self.id = id
......@@ -202,7 +204,10 @@ class WandbLogger(object):
self.run.define_metric("train/step")
self.run.define_metric("train/*", step_metric="train/step")
self.voc_dataset = VOCDetection
if val_dataset and self.num_log_images != 0:
self.val_dataset = val_dataset
self.cats = val_dataset.cats
self.id_to_class = {
cls['id']: cls['name'] for cls in self.cats
......@@ -241,8 +246,12 @@ class WandbLogger(object):
id = data_point[3]
img = np.transpose(img, (1, 2, 0))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
if isinstance(id, torch.Tensor):
id = id.item()
self.val_table.add_data(
id.item(),
id,
self.wandb.Image(img)
)
......@@ -250,6 +259,43 @@ class WandbLogger(object):
self.run.use_artifact(self.val_artifact)
self.val_artifact.wait()
def _convert_prediction_format(self, predictions):
image_wise_data = defaultdict(int)
for key, val in predictions.items():
img_id = key
try:
bboxes, cls, scores = val
except KeyError:
bboxes, cls, scores = val["bboxes"], val["categories"], val["scores"]
# These store information of actual bounding boxes i.e. the ones which are not None
act_box = []
act_scores = []
act_cls = []
if bboxes is not None:
for box, classes, score in zip(bboxes, cls, scores):
if box is None or score is None or classes is None:
continue
act_box.append(box)
act_scores.append(score)
act_cls.append(classes)
image_wise_data.update({
int(img_id): {
"bboxes": [box.numpy().tolist() for box in act_box],
"scores": [score.numpy().item() for score in act_scores],
"categories": [
self.val_dataset.class_ids[int(act_cls[ind])]
for ind in range(len(act_box))
],
}
})
return image_wise_data
def log_metrics(self, metrics, step=None):
"""
Args:
......@@ -277,16 +323,23 @@ class WandbLogger(object):
for cls in self.cats:
columns.append(cls["name"])
if isinstance(self.val_dataset, self.voc_dataset):
predictions = self._convert_prediction_format(predictions)
result_table = self.wandb.Table(columns=columns)
for idx, val in table_ref.iterrows():
avg_scores = defaultdict(int)
num_occurrences = defaultdict(int)
if val[0] in predictions:
prediction = predictions[val[0]]
boxes = []
id = val[0]
if isinstance(id, list):
id = id[0]
if id in predictions:
prediction = predictions[id]
boxes = []
for i in range(len(prediction["bboxes"])):
bbox = prediction["bboxes"][i]
x0 = bbox[0]
......@@ -310,7 +363,6 @@ class WandbLogger(object):
boxes.append(box)
else:
boxes = []
average_class_score = []
for cls in self.cats:
if cls["name"] not in num_occurrences:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册