Unverified commit fa92ccb9 authored by Roman Donchenko, committed by GitHub

SDK: Improve the PyTorch adapter layer (#5455)

* Make the extractors return tensors instead of Python data structures.
* Let the user specify custom label IDs.
Parent f6d2a8fe
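Before the diff, a hedged usage sketch of the two changes described above. The import alias, host, credentials, label names, and task/sample IDs below are illustrative assumptions, not part of this commit; the constructor arguments and transform names come from the diff itself.

from cvat_sdk import make_client           # assumed entry point for an authenticated client
import cvat_sdk.pytorch as cvatpt          # assumed import path for the PyTorch adapter module

with make_client(host="https://app.cvat.ai", credentials=("user", "password")) as client:
    dataset = cvatpt.TaskVisionDataset(
        client,
        123,  # hypothetical task ID
        # New in this commit: pin label indices by name instead of relying on
        # the default id-sorted assignment.
        label_name_to_index={"person": 0, "car": 1},
        target_transform=cvatpt.ExtractSingleLabelIndex(),
    )
    image, label = dataset[5]  # hypothetical sample index
    # Also new: the target transform yields a 0-dimensional torch.long tensor
    # instead of a plain Python int.
    print(label.dtype, label.shape)  # torch.int64 torch.Size([])

Returning tensors directly keeps the targets in the form that torchvision models and loss functions consume, so training loops no longer need per-sample conversions.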
@@ -24,6 +24,7 @@ import appdirs
import attrs
import attrs.validators
import PIL.Image
import torch
import torchvision.datasets
from typing_extensions import TypedDict
@@ -65,8 +66,7 @@ class Target:
label_id_to_index: Mapping[int, int]
"""
A mapping from label_id values in `LabeledImage` and `LabeledShape` objects
to an index in the range [0, num_labels), where num_labels is the number of labels
defined in the task. This mapping is consistent across all samples for a given task.
to an integer index. This mapping is consistent across all samples for a given task.
"""
@@ -99,6 +99,7 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset):
transforms: Optional[Callable] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
label_name_to_index: Mapping[str, int] = None,
) -> None:
"""
Creates a dataset corresponding to the task with ID `task_id` on the
@@ -107,6 +108,17 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset):
`transforms`, `transform` and `target_transform` are optional transformation
functions; see the documentation for `torchvision.datasets.VisionDataset` for
more information.
`label_name_to_index` affects the `label_id_to_index` member in `Target` objects
returned by the dataset. If it is specified, then it must contain an entry for
each label name in the task. The `label_id_to_index` mapping will be constructed
so that each label will be mapped to the index corresponding to the label's name
in `label_name_to_index`.
If `label_name_to_index` is unspecified or set to `None`, then `label_id_to_index`
will map each label ID to a distinct integer in the range [0, `num_labels`), where
`num_labels` is the number of labels defined in the task. This mapping will be
generally unpredictable, but consistent for a given task.
"""
self._logger = client.logger
@@ -162,12 +174,19 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset):
self._logger.info("All chunks downloaded")
self._label_id_to_index = types.MappingProxyType(
{
label["id"]: label_index
for label_index, label in enumerate(sorted(self._task.labels, key=lambda l: l.id))
}
)
if label_name_to_index is None:
self._label_id_to_index = types.MappingProxyType(
{
label.id: label_index
for label_index, label in enumerate(
sorted(self._task.labels, key=lambda l: l.id)
)
}
)
else:
self._label_id_to_index = types.MappingProxyType(
{label.id: label_name_to_index[label.name] for label in self._task.labels}
)
annotations = self._ensure_model(
"annotations.json", LabeledData, self._task.get_annotations, "annotations"
@@ -283,7 +302,7 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset):
class ExtractSingleLabelIndex:
"""
A target transform that takes a `Target` object and produces a single label index
based on the tag in that object.
based on the tag in that object, as a 0-dimensional tensor.
This makes the dataset samples compatible with the image classification networks
in torchvision.
@@ -299,12 +318,12 @@ class ExtractSingleLabelIndex:
if len(tags) > 1:
raise ValueError("sample has multiple tags")
return target.label_id_to_index[tags[0].label_id]
return torch.tensor(target.label_id_to_index[tags[0].label_id], dtype=torch.long)
class LabeledBoxes(TypedDict):
boxes: Sequence[Tuple[float, float, float, float]]
labels: Sequence[int]
boxes: torch.Tensor
labels: torch.Tensor
_SUPPORTED_SHAPE_TYPES = frozenset(["rectangle", "polygon", "polyline", "points", "ellipse"])
@@ -318,9 +337,9 @@ class ExtractBoundingBoxes:
The dictionary contains the following entries:
"boxes": a sequence of (xmin, ymin, xmax, ymax) tuples, one for each shape
in the annotations.
"labels": a sequence of corresponding label indices.
"boxes": a tensor with shape [N, 4], where each row represents a bounding box of a shape
in the annotations in the (xmin, ymin, xmax, ymax) format.
"labels": a tensor with shape [N] containing corresponding label indices.
Limitations:
@@ -356,4 +375,7 @@ class ExtractBoundingBoxes:
boxes.append((min(x_coords), min(y_coords), max(x_coords), max(y_coords)))
labels.append(target.label_id_to_index[shape.label_id])
return LabeledBoxes(boxes=boxes, labels=labels)
return LabeledBoxes(
boxes=torch.tensor(boxes, dtype=torch.float),
labels=torch.tensor(labels, dtype=torch.long),
)
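The constructor hunk above shows the two ways `label_id_to_index` can now be built. A minimal, self-contained sketch of that logic, using made-up label objects rather than the SDK's task model, may make the difference clearer:

import types
from collections import namedtuple

Label = namedtuple("Label", ["id", "name"])
task_labels = [Label(17, "car"), Label(5, "person")]

# Default behaviour: labels sorted by ID receive consecutive indices in [0, num_labels).
default_mapping = types.MappingProxyType(
    {label.id: index for index, label in enumerate(sorted(task_labels, key=lambda l: l.id))}
)
assert default_mapping == {5: 0, 17: 1}

# With label_name_to_index: each label ID maps to the index registered for its name.
label_name_to_index = {"person": 123, "car": 456}
custom_mapping = types.MappingProxyType(
    {label.id: label_name_to_index[label.name] for label in task_labels}
)
assert custom_mapping == {5: 123, 17: 456}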
@@ -165,8 +165,8 @@ class TestTaskVisionDataset:
target_transform=cvatpt.ExtractSingleLabelIndex(),
)
assert dataset[5][1] == 0
assert dataset[6][1] == 1
assert torch.equal(dataset[5][1], torch.tensor(0))
assert torch.equal(dataset[6][1], torch.tensor(1))
with pytest.raises(ValueError):
# no tags
@@ -192,9 +192,15 @@ class TestTaskVisionDataset:
target_transform=cvatpt.ExtractBoundingBoxes(include_shape_types={"rectangle"}),
)
assert dataset[0][1] == {"boxes": [], "labels": []}
assert dataset[6][1] == {"boxes": [(1.0, 2.0, 3.0, 4.0)], "labels": [1]}
assert dataset[7][1] == {"boxes": [], "labels": []} # points are filtered out
assert torch.equal(dataset[0][1]["boxes"], torch.tensor([]))
assert torch.equal(dataset[0][1]["labels"], torch.tensor([]))
assert torch.equal(dataset[6][1]["boxes"], torch.tensor([(1.0, 2.0, 3.0, 4.0)]))
assert torch.equal(dataset[6][1]["labels"], torch.tensor([1]))
# points are filtered out
assert torch.equal(dataset[7][1]["boxes"], torch.tensor([]))
assert torch.equal(dataset[7][1]["labels"], torch.tensor([]))
def test_transforms(self):
dataset = cvatpt.TaskVisionDataset(
@@ -205,3 +211,16 @@ class TestTaskVisionDataset:
assert isinstance(dataset[0][0], cvatpt.Target)
assert isinstance(dataset[0][1], PIL.Image.Image)
def test_custom_label_mapping(self):
label_name_to_id = {label.name: label.id for label in self.task.labels}
dataset = cvatpt.TaskVisionDataset(
self.client,
self.task.id,
label_name_to_index={"person": 123, "car": 456},
)
_, target = dataset[5]
assert target.label_id_to_index[label_name_to_id["person"]] == 123
assert target.label_id_to_index[label_name_to_id["car"]] == 456
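The updated tests above exercise the new tensor outputs. One practical consequence (illustrative, not part of the commit): after default DataLoader collation, the 0-dimensional long tensors produced by ExtractSingleLabelIndex stack into the [N] int64 target tensor that torch.nn.functional.cross_entropy expects.

import torch
import torch.nn.functional as F

# Simulate a collated batch of two samples whose targets came from ExtractSingleLabelIndex.
targets = torch.stack([torch.tensor(0), torch.tensor(1)])  # shape [2], dtype torch.int64
logits = torch.randn(2, 5)                                  # raw scores for 5 classes
loss = F.cross_entropy(logits, targets)                     # works without any conversion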