reader_utils.py 2.3 KB
Newer Older
D
dengkaipeng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#  Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

15 16 17 18 19 20 21
import pickle
import cv2
import numpy as np
import random


class ReaderNotFoundError(Exception):
    """Raised when a requested reader name is not registered.

    Args:
        reader_name: the name that was looked up.
        avail_readers: iterable of the registered reader names; each one is
            listed on its own line in the message produced by ``__str__``.
    """

    def __init__(self, reader_name, avail_readers):
        # Materialize once: callers may pass a one-shot iterable or a dict
        # keys view, and the names are iterated every time __str__ runs.
        avail_readers = list(avail_readers)
        # Forward the constructor arguments so ``self.args`` is populated.
        # BaseException pickling re-creates the instance as cls(*self.args),
        # which fails for this two-argument __init__ when args is empty.
        super(ReaderNotFoundError, self).__init__(reader_name, avail_readers)
        self.reader_name = reader_name
        self.avail_readers = avail_readers

    def __str__(self):
        header = "Reader {} Not Found.\nAvailable readers:\n".format(
            self.reader_name)
        return header + "".join(
            "  {}\n".format(reader) for reader in self.avail_readers)


class DataReader(object):
    """data reader for video input

    Base class for readers registered in the reader zoo. It only stores its
    constructor arguments and offers a config-lookup helper; concrete
    readers must override ``create_reader``.

    Args:
        model_name: name the reader was looked up under (stored as ``name``).
        mode: run mode passed through by the caller (presumably a string such
            as train/valid/test -- TODO confirm against callers).
        cfg: mapping from upper-case section names to per-section mappings
            that support ``.get(item, default)``.
    """

    def __init__(self, model_name, mode, cfg):
        self.name = model_name
        self.mode = mode
        self.cfg = cfg

    def create_reader(self):
        """Build and return the reader; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, on the base class. Returning None
                silently (the previous stub behavior) only defers the failure
                to the first use of the missing reader.
        """
        raise NotImplementedError(
            "{} must implement create_reader()".format(type(self).__name__))

    def get_config_from_sec(self, sec, item, default=None):
        """Return ``cfg[SEC][item]``, or ``default`` if either key is missing.

        The section name is upper-cased before lookup, so ``sec`` is
        case-insensitive.
        """
        section = sec.upper()
        if section not in self.cfg:
            return default
        return self.cfg[section].get(item, default)


55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79

class ReaderZoo(object):
    """Registry mapping reader names to ``DataReader`` subclasses."""

    def __init__(self):
        # registered name -> DataReader subclass (the class, not an instance)
        self.reader_zoo = {}

    def regist(self, name, reader):
        """Register the reader class ``reader`` under ``name``.

        Re-registering an existing name replaces the previous class.

        Raises:
            TypeError: if ``reader`` is not a subclass of ``DataReader``.
        """
        # An explicit check rather than ``assert``, so that validation still
        # runs under ``python -O``.
        if not (isinstance(reader, type) and issubclass(reader, DataReader)):
            raise TypeError("Unknown reader type {}".format(reader))
        self.reader_zoo[name] = reader

    def get(self, name, mode, cfg):
        """Instantiate the reader registered as ``name`` with (name, mode, cfg).

        Raises:
            ReaderNotFoundError: if ``name`` has not been registered.
        """
        if name not in self.reader_zoo:
            # Pass a snapshot list rather than the live keys view.
            raise ReaderNotFoundError(name, list(self.reader_zoo))
        return self.reader_zoo[name](name, mode, cfg)


# Module-level singleton registry shared by regist_reader() and get_reader().
reader_zoo = ReaderZoo()


def regist_reader(name, reader):
    """Register the reader class ``reader`` under ``name`` in ``reader_zoo``.

    Validation of ``reader`` is delegated to ``ReaderZoo.regist``.
    """
    zoo = reader_zoo
    zoo.regist(name, reader)


80
def get_reader(name, mode, cfg):
    """Look up the reader registered as ``name`` and build its reader.

    Instantiates the registered class with (name, mode, cfg) via
    ``reader_zoo.get`` and returns whatever its ``create_reader()`` returns.

    Raises:
        ReaderNotFoundError: propagated from ``reader_zoo.get`` when ``name``
            is not registered.
    """
    return reader_zoo.get(name, mode, cfg).create_reader()