# paddlecv.py (2.3 KB) — committed by wangguanzhong.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import importlib
import argparse

# Directory containing this module file.
__dir__ = os.path.dirname(__file__)

# Put this file's directory at the front of sys.path (the trailing '' only
# adds a path separator). Presumably this is so that the `ppcv`, `tools` and
# `tests` packages shipped next to this file are found before any installed
# copies — must run before those imports below.
sys.path.insert(0, os.path.join(__dir__, ''))

# NOTE(review): cv2, logging, numpy and Path appear unused in this module;
# they may be kept deliberately (e.g. to fail fast on a missing dependency) —
# confirm before removing.
import cv2
import logging
import numpy as np
from pathlib import Path

# importlib.import_module('.', pkg) resolves the relative name '.' against
# `pkg`, so each call is equivalent to `import pkg`. The resulting modules are
# bound as attributes of this module.
ppcv = importlib.import_module('.', 'ppcv')
tools = importlib.import_module('.', 'tools')
tests = importlib.import_module('.', 'tests')

# Version string of this wrapper module.
VERSION = '0.1.0'

import yaml
from ppcv.model_zoo.model_zoo import TASK_DICT, list_model, get_config_file
from ppcv.engine.pipeline import Pipeline
from ppcv.utils.logger import setup_logger

# Module-level logger used by the PaddleCV class below.
logger = setup_logger()


class PaddleCV(object):
    """High-level entry point that builds a ppcv ``Pipeline`` from a task or config.

    Either ``task_name`` (a key of ``TASK_DICT``) or ``config_path`` must be
    given. When both are given, ``task_name`` takes precedence: its config is
    resolved with ``get_config_file`` and ``config_path`` is ignored (a warning
    is logged).

    Args:
        task_name: Name of a built-in task; must be a key of ``TASK_DICT``.
        config_path: Path to a pipeline config file; used only when
            ``task_name`` is None.
        output_dir: Forwarded to the pipeline config as ``output_dir``.
        run_mode: Forwarded to the pipeline config as ``run_mode``
            (default ``'paddle'``).
        device: Forwarded to the pipeline config as ``device``
            (default ``'CPU'``).

    Raises:
        AssertionError: If ``task_name`` is not a key of ``TASK_DICT``, or if
            ``task_name`` and ``config_path`` are both None (only while
            assertions are enabled, i.e. not under ``python -O``).
    """

    def __init__(self,
                 task_name=None,
                 config_path=None,
                 output_dir=None,
                 run_mode='paddle',
                 device='CPU'):
        # NOTE(review): validation uses ``assert`` so callers keep seeing
        # AssertionError, but these checks vanish under ``python -O``; an
        # explicit ``raise`` would be sturdier if the exception type may change.
        if task_name is not None:
            assert task_name in TASK_DICT, f"task_name must be one of {list(TASK_DICT.keys())} but got {task_name}"
            if config_path is not None:
                # The task's bundled config always wins; don't drop the
                # caller's explicit path silently.
                logger.warning(
                    f"Both task_name ({task_name}) and config_path "
                    f"({config_path}) were given; config_path is ignored.")
            config_path = get_config_file(task_name)
        else:
            assert config_path is not None, "task_name and config_path can not be None at the same time!!!"

        # Pipeline expects an argparse-style namespace (attribute access), so
        # the keyword arguments are packed into one.
        self.cfg_dict = dict(
            config=config_path,
            output_dir=output_dir,
            run_mode=run_mode,
            device=device)
        cfg = argparse.Namespace(**self.cfg_dict)
        self.pipeline = Pipeline(cfg)

    @classmethod
    def list_all_supported_tasks(cls):
        """Log a header, then print ``TASK_DICT`` (task -> config) as YAML to stdout."""
        logger.info(
            "Tasks and recommended configs that paddlecv supports are:")
        print(yaml.dump(TASK_DICT))

    @classmethod
    def list_all_supported_models(cls, filters=None):
        """Show supported models by delegating to ``list_model``.

        Args:
            filters: Optional sequence forwarded to ``list_model`` (its exact
                meaning is defined there — presumably name/keyword filters).
                None, the default, is forwarded as an empty list.
        """
        # Build a fresh list per call: a shared ``[]`` default would carry any
        # mutation made by ``list_model`` over into later calls.
        list_model([] if filters is None else filters)

    def __call__(self, input):
        """Run the pipeline on ``input`` and return whatever ``Pipeline.run`` returns.

        ``input`` is passed through unchanged; the accepted input types are
        defined by ``Pipeline.run`` (TODO confirm: image path(s) / arrays?).
        """
        return self.pipeline.run(input)