未验证 提交 ad534b2a 编写于 作者: M Maxim Zhiltsov 提交者: GitHub

Refactor server api tests and dataset_manifest (#5997)

Extracted from #5083 

- Fixed test data placement in tests
- Fixed cache cleaning in tests
- Refactored server API tests and dataset_manifest
- Added more tests
上级 31c1ecc4
......@@ -2,14 +2,115 @@
#
# SPDX-License-Identifier: MIT
import itertools
from contextlib import contextmanager
from io import BytesIO
from typing import Callable, Iterator, TypeVar
import itertools
import logging
import os
from django.core.cache import caches
from django.http.response import HttpResponse
from PIL import Image
from rest_framework.test import APIClient, APITestCase
import av
import numpy as np
T = TypeVar('T')
@contextmanager
def logging_disabled():
    """
    Context manager that silences logging for the duration of the ``with`` block.

    The current global "disable" threshold is read from the root logger's
    manager on entry, logging is disabled up to and including CRITICAL, and
    the saved threshold is put back on exit, even if the body raises.
    """
    log_manager = logging.getLogger().manager
    saved_threshold = log_manager.disable

    try:
        logging.disable(logging.CRITICAL)
        yield
    finally:
        # Restore whatever threshold was active before entering the block
        logging.disable(saved_threshold)
class ForceLogin:
    """
    Context manager that logs ``user`` in on ``client`` for the ``with`` block.

    When ``user`` is falsy, nothing is done on entry or exit, so the
    block runs with whatever session state the client already has.
    """

    def __init__(self, user, client):
        self.user = user
        self.client = client

    def __enter__(self):
        """Force-login the user (if any) and return this context object."""
        if not self.user:
            return self

        self.client.force_login(
            self.user,
            backend='django.contrib.auth.backends.ModelBackend',
        )
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        """Log the user out again; exceptions from the block are not suppressed."""
        if not self.user:
            return

        self.client.logout()
class ApiTestBase(APITestCase):
    """
    Common base for server API tests.

    Gives every test a fresh ``APIClient`` and empties all initialized
    Django caches once the test has finished.
    """

    def setUp(self):
        super().setUp()
        self.client = APIClient()

    def tearDown(self):
        # Drop the server frame/chunk cache after each test.
        # The parent class rolls back DB changes, which can leave task data
        # only partially cleaned up and thereby influence later tests.
        # A real server is not expected to hit this, because cache keys
        # contain Data object ids, and those cannot be reused or freed
        # in real scenarios.
        for cache_backend in caches.all(initialized_only=True):
            cache_backend.clear()

        return super().tearDown()
def generate_image_file(filename, size=(100, 100)):
    """
    Build an in-memory JPEG image and return it as a named ``BytesIO``.

    Args:
        filename: value stored in the ``name`` attribute of the returned
            buffer; its extension must be empty, ``.jpg`` or ``.jpeg``
            (case-insensitive).
        size: ``(width, height)`` passed to ``Image.new``.

    Returns:
        A ``BytesIO`` positioned at offset 0 that holds the encoded image.

    Raises:
        AssertionError: if the file name has an unsupported extension.
    """
    extension = os.path.splitext(filename)[-1].lower()
    assert extension in ['', '.jpg', '.jpeg'], \
        "This function supports only jpeg images. Please add the .jpg extension to the file name"

    buffer = BytesIO()
    Image.new('RGB', size=size).save(buffer, 'jpeg')

    # Rewind and label the buffer so it can be consumed like an uploaded file
    buffer.name = filename
    buffer.seek(0)
    return buffer
def generate_video_file(filename, width=1920, height=1080, duration=1, fps=25, codec_name='mpeg4'):
    """
    Encode a synthetic video into memory and return it as a named ``BytesIO``.

    Every frame is filled with a single color that changes from frame to frame.

    Args:
        filename: value stored in the ``name`` attribute of the returned
            buffer; its extension (without the dot) selects the container format.
        width: frame width in pixels.
        height: frame height in pixels.
        duration: multiplied by ``fps`` to get the number of frames.
        fps: stream frame rate.
        codec_name: codec passed to ``container.add_stream``.

    Returns:
        A tuple ``(sizes, file)`` where ``sizes`` is a list containing
        ``(width, height)`` once per frame and ``file`` is the ``BytesIO``
        with the encoded video, positioned at offset 0.
    """
    f = BytesIO()
    total_frames = duration * fps
    file_ext = os.path.splitext(filename)[1][1:]

    container = av.open(f, mode='w', format=file_ext)
    try:
        stream = container.add_stream(codec_name=codec_name, rate=fps)
        stream.width = width
        stream.height = height
        stream.pix_fmt = 'yuv420p'

        for frame_i in range(total_frames):
            # Image arrays are laid out as (rows, columns, channels),
            # i.e. (height, width, 3) - not (width, height, 3)
            img = np.empty((height, width, 3))
            img[:, :, 0] = 0.5 + 0.5 * np.sin(2 * np.pi * (0 / 3 + frame_i / total_frames))
            img[:, :, 1] = 0.5 + 0.5 * np.sin(2 * np.pi * (1 / 3 + frame_i / total_frames))
            img[:, :, 2] = 0.5 + 0.5 * np.sin(2 * np.pi * (2 / 3 + frame_i / total_frames))

            # Clip while the data is still floating point: clipping after the
            # uint8 cast cannot have any effect
            img = np.clip(np.round(255 * img), 0, 255).astype(np.uint8)

            frame = av.VideoFrame.from_ndarray(img, format='rgb24')
            for packet in stream.encode(frame):
                container.mux(packet)

        # Flush stream
        for packet in stream.encode():
            container.mux(packet)
    finally:
        # Close the file, also when encoding fails half-way
        container.close()

    f.name = filename
    f.seek(0)

    return [(width, height)] * total_frames, f
def get_paginated_collection(
request_chunk_callback: Callable[[int], HttpResponse]
) -> Iterator[T]:
......
......@@ -392,9 +392,10 @@ STATIC_URL = '/static/'
STATIC_ROOT = os.path.join(BASE_DIR, 'static')
os.makedirs(STATIC_ROOT, exist_ok=True)
# Make sure to update other config files when upading these directories
# Make sure to update other config files when updating these directories
DATA_ROOT = os.path.join(BASE_DIR, 'data')
EVENTS_LOCAL_DB = os.path.join(DATA_ROOT,'events.db')
EVENTS_LOCAL_DB = os.path.join(DATA_ROOT, 'events.db')
os.makedirs(DATA_ROOT, exist_ok=True)
if not os.path.exists(EVENTS_LOCAL_DB):
open(EVENTS_LOCAL_DB, 'w').close()
......
......@@ -18,7 +18,7 @@ BASE_DIR = _temp_dir.name
DATA_ROOT = os.path.join(BASE_DIR, 'data')
os.makedirs(DATA_ROOT, exist_ok=True)
EVENTS_LOCAL_DB = os.path.join(DATA_ROOT,'logstash.db')
EVENTS_LOCAL_DB = os.path.join(DATA_ROOT, 'events.db')
os.makedirs(DATA_ROOT, exist_ok=True)
if not os.path.exists(EVENTS_LOCAL_DB):
open(EVENTS_LOCAL_DB, 'w').close()
......@@ -67,6 +67,8 @@ for logger in LOGGING["loggers"].values():
LOGGING["handlers"]["server_file"] = LOGGING["handlers"]["console"]
CACHES["media"]["LOCATION"] = CACHE_ROOT
PASSWORD_HASHERS = (
'django.contrib.auth.hashers.MD5PasswordHasher',
)
......
# Copyright (C) 2021-2022 Intel Corporation
# Copyright (C) 2022-2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from enum import Enum
from io import StringIO
import av
import json
import os
from abc import ABC, abstractmethod, abstractproperty, abstractstaticmethod
from contextlib import closing
from tempfile import NamedTemporaryFile
from PIL import Image
from json.decoder import JSONDecodeError
from .errors import InvalidManifestError, InvalidVideoFrameError
from .utils import SortingMethod, md5_hash, rotate_image, sort
class VideoStreamReader:
......@@ -24,14 +26,11 @@ class VideoStreamReader:
with closing(av.open(self.source_path, mode='r')) as container:
video_stream = VideoStreamReader._get_video_stream(container)
isBreaked = False
for packet in container.demux(video_stream):
if isBreaked:
break
for frame in packet.decode():
# check type of first frame
if not frame.pict_type.name == 'I':
raise Exception('First frame is not key frame')
raise InvalidVideoFrameError('First frame is not key frame')
# get video resolution
if video_stream.metadata.get('rotate'):
......@@ -43,11 +42,12 @@ class VideoStreamReader:
format ='bgr24',
)
self.height, self.width = (frame.height, frame.width)
# not all videos contain information about numbers of frames
if video_stream.frames:
self._frames_number = video_stream.frames
isBreaked = True
break
return
@property
def source_path(self):
......@@ -81,9 +81,9 @@ class VideoStreamReader:
for packet in container.demux(video_stream):
for frame in packet.decode():
if None not in {frame.pts, frame_pts} and frame.pts <= frame_pts:
raise Exception('Invalid pts sequences')
raise InvalidVideoFrameError('Invalid pts sequences')
if None not in {frame.dts, frame_dts} and frame.dts <= frame_dts:
raise Exception('Invalid dts sequences')
raise InvalidVideoFrameError('Invalid dts sequences')
frame_pts, frame_dts = frame.pts, frame.dts
if frame.key_frame:
......@@ -122,9 +122,9 @@ class KeyFramesVideoStreamReader(VideoStreamReader):
for packet in container.demux(video_stream):
for frame in packet.decode():
if None not in {frame.pts, frame_pts} and frame.pts <= frame_pts:
raise Exception('Invalid pts sequences')
raise InvalidVideoFrameError('Invalid pts sequences')
if None not in {frame.dts, frame_dts} and frame.dts <= frame_dts:
raise Exception('Invalid dts sequences')
raise InvalidVideoFrameError('Invalid dts sequences')
frame_pts, frame_dts = frame.pts, frame.dts
if frame.key_frame:
......@@ -148,13 +148,13 @@ class KeyFramesVideoStreamReader(VideoStreamReader):
class DatasetImagesReader:
def __init__(self,
sources,
meta=None,
sorting_method=SortingMethod.PREDEFINED,
use_image_hash=False,
*,
start = 0,
step = 1,
stop = None,
*args,
meta=None,
sorting_method=SortingMethod.PREDEFINED,
use_image_hash=False,
**kwargs):
self._sources = sort(sources, sorting_method)
self._meta = meta
......@@ -194,23 +194,28 @@ class DatasetImagesReader:
if idx in self.range_:
image = next(sources)
img = Image.open(image, mode='r')
orientation = img.getexif().get(274, 1)
img_name = os.path.relpath(image, self._data_dir) if self._data_dir \
else os.path.basename(image)
name, extension = os.path.splitext(img_name)
width, height = img.width, img.height
if orientation > 4:
width, height = height, width
image_properties = {
'name': name.replace('\\', '/'),
'extension': extension,
'width': width,
'height': height,
}
width, height = img.width, img.height
orientation = img.getexif().get(274, 1)
if orientation > 4:
width, height = height, width
image_properties['width'] = width
image_properties['height'] = height
if self._meta and img_name in self._meta:
image_properties['meta'] = self._meta[img_name]
if self._use_image_hash:
image_properties['checksum'] = md5_hash(img)
yield image_properties
else:
yield dict()
......@@ -258,6 +263,7 @@ class _Manifest:
FILE_NAME = 'manifest.jsonl'
VERSION = SupportedVersion.V1_1
TYPE: str # must be set externally
def __init__(self, path, upload_dir=None):
assert path, 'A path to manifest file not found'
......@@ -273,6 +279,13 @@ class _Manifest:
return os.path.basename(self._path) if not self._upload_dir \
else os.path.relpath(self._path, self._upload_dir)
def get_header_lines_count(self) -> int:
    """
    Return the number of header lines for this manifest type.

    Returns:
        3 when ``self.TYPE`` is ``'video'``, 2 when it is ``'images'``.

    Raises:
        AssertionError: if ``self.TYPE`` has any other value.
    """
    if self.TYPE == 'video':
        return 3
    elif self.TYPE == 'images':
        return 2

    # Raised explicitly instead of `assert False`: assert statements are
    # removed under `python -O`, which would make this silently return None
    raise AssertionError(f"Unknown manifest type '{self.TYPE}'")
# Needed for faster iteration over the manifest file, will be generated to work inside CVAT
# and will not be generated when manually creating a manifest
class _Index:
......@@ -299,7 +312,7 @@ class _Index:
def remove(self):
os.remove(self._path)
def create(self, manifest, skip):
def create(self, manifest, *, skip):
assert os.path.exists(manifest), 'A manifest file not exists, index cannot be created'
with open(manifest, 'r+') as manifest_file:
while skip:
......@@ -327,20 +340,14 @@ class _Index:
line = manifest_file.readline()
def __getitem__(self, number):
assert 0 <= number < len(self), \
'Invalid index number: {}\nMax: {}'.format(number, len(self) - 1)
if not 0 <= number < len(self):
raise IndexError('Invalid index number: {}\nMax: {}'.format(number, len(self) - 1))
return self._index[number]
def __len__(self):
return len(self._index)
def _set_index(func):
def wrapper(self, *args, **kwargs):
func(self, *args, **kwargs)
if self._create_index:
self.set_index()
return wrapper
class _ManifestManager(ABC):
BASE_INFORMATION = {
'version' : 1,
......@@ -348,11 +355,14 @@ class _ManifestManager(ABC):
}
def _json_item_is_valid(self, **state):
for item in self._requared_item_attributes:
for item in self._required_item_attributes:
if state.get(item, None) is None:
raise Exception(f"Invalid '{self.manifest.name} file structure': '{item}' is required, but not found")
raise InvalidManifestError(
f"Invalid '{self.manifest.name}' file structure: "
f"'{item}' is required, but not found"
)
def __init__(self, path, create_index, upload_dir=None, *args, **kwargs):
def __init__(self, path, create_index, upload_dir=None):
self._manifest = _Manifest(path, upload_dir)
self._index = _Index(os.path.dirname(self._manifest.path))
self._reader = None
......@@ -384,11 +394,12 @@ class _ManifestManager(ABC):
if os.path.exists(self._index.path):
self._index.load()
else:
self._index.create(self._manifest.path, 3 if self._manifest.TYPE == 'video' else 2)
self._index.dump()
self._index.create(self._manifest.path, skip=self._manifest.get_header_lines_count())
if self._create_index:
self._index.dump()
def reset_index(self):
if os.path.exists(self._index.path):
if self._create_index and os.path.exists(self._index.path):
self._index.remove()
def set_index(self):
......@@ -402,24 +413,23 @@ class _ManifestManager(ABC):
@abstractmethod
def create(self, content=None, _tqdm=None):
pass
...
@abstractmethod
def partial_update(self, number, properties):
pass
...
def __iter__(self):
self.set_index()
with open(self._manifest.path, 'r') as manifest_file:
manifest_file.seek(self._index[0])
image_number = 0
line = manifest_file.readline()
while line:
if line.strip():
parsed_properties = json.loads(line)
self._json_item_is_valid(**parsed_properties)
yield (image_number, parsed_properties)
image_number += 1
for idx, line_start in enumerate(self._index):
manifest_file.seek(line_start)
line = manifest_file.readline()
item = json.loads(line)
self._json_item_is_valid(**item)
yield (idx, item)
@property
def manifest(self):
......@@ -440,14 +450,14 @@ class _ManifestManager(ABC):
@abstractproperty
def data(self):
pass
...
@abstractmethod
def get_subset(self, subset_names):
pass
...
class VideoManifestManager(_ManifestManager):
_requared_item_attributes = {'number', 'pts'}
_required_item_attributes = {'number', 'pts'}
def __init__(self, manifest_path, create_index=True):
super().__init__(manifest_path, create_index)
......@@ -487,24 +497,22 @@ class VideoManifestManager(_ManifestManager):
}, separators=(',', ':'))
file.write(f"{json_item}\n")
# pylint: disable=arguments-differ
@_set_index
def create(self, _tqdm=None):
def create(self, *, _tqdm=None): # pylint: disable=arguments-differ
""" Creating and saving a manifest file """
if not len(self._reader):
with NamedTemporaryFile(mode='w', delete=False)as tmp_file:
self._write_core_part(tmp_file, _tqdm)
temp = tmp_file.name
tmp_file = StringIO()
self._write_core_part(tmp_file, _tqdm)
with open(self._manifest.path, 'w') as manifest_file:
self._write_base_information(manifest_file)
with open(temp, 'r') as tmp_file:
manifest_file.write(tmp_file.read())
os.remove(temp)
manifest_file.write(tmp_file.getvalue())
else:
with open(self._manifest.path, 'w') as manifest_file:
self._write_base_information(manifest_file)
self._write_core_part(manifest_file, _tqdm)
self.set_index()
def partial_update(self, number, properties):
pass
......@@ -567,7 +575,7 @@ class VideoManifestValidator(VideoManifestManager):
return
class ImageManifestManager(_ManifestManager):
_requared_item_attributes = {'name', 'extension'}
_required_item_attributes = {'name', 'extension'}
def __init__(self, manifest_path, upload_dir=None, create_index=True):
super().__init__(manifest_path, create_index, upload_dir)
......@@ -596,7 +604,6 @@ class ImageManifestManager(_ManifestManager):
}, separators=(',', ':'))
file.write(f"{json_line}\n")
@_set_index
def create(self, content=None, _tqdm=None):
""" Creating and saving a manifest file for the specialized dataset"""
with open(self._manifest.path, 'w') as manifest_file:
......@@ -604,6 +611,8 @@ class ImageManifestManager(_ManifestManager):
obj = content if content else self._reader
self._write_core_part(manifest_file, obj, _tqdm)
self.set_index()
def partial_update(self, number, properties):
pass
......@@ -644,17 +653,17 @@ class _BaseManifestValidator(ABC):
line = json.loads(manifest.readline().strip())
validator(line)
return True
except (ValueError, KeyError, JSONDecodeError):
except (ValueError, KeyError, JSONDecodeError, InvalidManifestError):
return False
@staticmethod
def _validate_version(_dict):
if not _dict['version'] in _Manifest.SupportedVersion.choices():
raise ValueError('Incorrect version field')
raise InvalidManifestError('Incorrect version field')
def _validate_type(self, _dict):
if not _dict['type'] == self.TYPE:
raise ValueError('Incorrect type field')
raise InvalidManifestError('Incorrect type field')
@abstractproperty
def validators(self):
......@@ -680,18 +689,18 @@ class _VideoManifestStructureValidator(_BaseManifestValidator):
def _validate_properties(_dict):
properties = _dict['properties']
if not isinstance(properties['name'], str):
raise ValueError('Incorrect name field')
raise InvalidManifestError('Incorrect name field')
if not isinstance(properties['resolution'], list):
raise ValueError('Incorrect resolution field')
raise InvalidManifestError('Incorrect resolution field')
if not isinstance(properties['length'], int) or properties['length'] == 0:
raise ValueError('Incorrect length field')
raise InvalidManifestError('Incorrect length field')
@staticmethod
def _validate_first_item(_dict):
if not isinstance(_dict['number'], int):
raise ValueError('Incorrect number field')
raise InvalidManifestError('Incorrect number field')
if not isinstance(_dict['pts'], int):
raise ValueError('Incorrect pts field')
raise InvalidManifestError('Incorrect pts field')
class _DatasetManifestStructureValidator(_BaseManifestValidator):
TYPE = 'images'
......@@ -707,18 +716,18 @@ class _DatasetManifestStructureValidator(_BaseManifestValidator):
@staticmethod
def _validate_first_item(_dict):
if not isinstance(_dict['name'], str):
raise ValueError('Incorrect name field')
raise InvalidManifestError('Incorrect name field')
if not isinstance(_dict['extension'], str):
raise ValueError('Incorrect extension field')
raise InvalidManifestError('Incorrect extension field')
# FIXME
# Width and height are required for 2D data, but
# for 3D these parameters are not saved now.
# It is necessary to uncomment these restrictions when manual preparation for 3D data is implemented.
# if not isinstance(_dict['width'], int):
# raise ValueError('Incorrect width field')
# raise InvalidManifestError('Incorrect width field')
# if not isinstance(_dict['height'], int):
# raise ValueError('Incorrect height field')
# raise InvalidManifestError('Incorrect height field')
def is_manifest(full_manifest_path):
return _is_video_manifest(full_manifest_path) or \
......
# Copyright (C) 2023 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
class BasicError(Exception):
    """
    The base exception type for all exceptions defined by this library.

    The other exception classes in this module derive from it, so catching
    BasicError handles any of them.
    """
class InvalidVideoFrameError(BasicError):
    """
    Indicates an invalid video frame.

    Raised by the video stream readers, for example when the first frame is
    not a key frame or when the pts/dts values of consecutive frames do not
    increase.
    """
class InvalidManifestError(BasicError):
    """
    Indicates an invalid manifest.

    Raised, for example, when a manifest item lacks a required attribute or
    when a header or item field fails validation.
    """
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册