提交 d675b020 编写于 作者: W weishengyu

new structure

上级 14348aca
from ..processor import build_processor
class POPEngine:
def __init__(self, config):
self.processor_list = []
last_algo_type = "start"
for processor_config in config["Processors"]:
processor_config["last_algo_type"] = last_algo_type
self.processor_list.append(build_processor(processor_config))
last_algo_type = processor_config["type"]
def process(self, x):
for processor in self.processor_list:
x = processor.process(x)
return x
from ..engine import build_engine
from ..utils import config
def main():
args = config.parse_args()
config_dict = config.get_config(
args.config, overrides=args.override, show=False)
config_dict.profiler_options = args.profiler_options
engine = build_engine(config_dict)
if __name__ == '__main__':
main()
import cv2
class SingleImageTask:
def __init__(self, config, engine):
self.image_path = config.get("image_path")
self.engine = engine
def run(self):
image = cv2.imread(self.image_path)
output = self.engine.process(image)
print(output)
from abc import ABC, abstractmethod
from algo_mod import build_algo_mod
from searcher import build_searcher
from data_processor import build_data_processor
def build_processor(config):
processor_type = config.get("processor_type")
if processor_type == "algo_mod":
return build_algo_mod(config)
elif processor_type == "searcher":
return build_searcher(config)
elif processor_type == "data_processor":
return build_data_processor(config)
else:
raise NotImplemented("processor_type {} not implemented.".format(processor_type))
class BaseProcessor(ABC):
@abstractmethod
def __init__(self, config):
pass
@abstractmethod
def process(self, input_data):
pass
from .fake_cls import FakeClassifier
def build_algo_mod(config):
algo_name = config.get("algo_name")
if algo_name == "fake_clas":
return FakeClassifier(config)
from .. import BaseProcessor
class FakeClassifier(BaseProcessor):
def __init__(self, config):
pass
def process(self, input_data):
pass
class FakeDetector:
def __init__(self):
pass
def predict(self):
pass
from paddle.inference import create_predictor, Config
# from bbox_cropper import
def build_data_processor(config):
return
from .. import BaseProcessor
class BBoxCropper(BaseProcessor):
def __init__(self, config):
pass
def process(self, input_data):
pass
from .. import BaseProcessor
class ImageReader(BaseProcessor):
def __init__(self):
pass
def process(self, input_data):
pass
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
import argparse
import yaml
from utils import logger
__all__ = ['get_config']
class AttrDict(dict):
def __getattr__(self, key):
return self[key]
def __setattr__(self, key, value):
if key in self.__dict__:
self.__dict__[key] = value
else:
self[key] = value
def __deepcopy__(self, content):
return copy.deepcopy(dict(self))
def create_attr_dict(yaml_config):
from ast import literal_eval
for key, value in yaml_config.items():
if type(value) is dict:
yaml_config[key] = value = AttrDict(value)
if isinstance(value, str):
try:
value = literal_eval(value)
except BaseException:
pass
if isinstance(value, AttrDict):
create_attr_dict(yaml_config[key])
else:
yaml_config[key] = value
def parse_config(cfg_file):
"""Load a config file into AttrDict"""
with open(cfg_file, 'r') as fopen:
yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.SafeLoader))
create_attr_dict(yaml_config)
return yaml_config
def print_dict(d, delimiter=0):
"""
Recursively visualize a dict and
indenting acrrording by the relationship of keys.
"""
placeholder = "-" * 60
for k, v in sorted(d.items()):
if isinstance(v, dict):
logger.info("{}{} : ".format(delimiter * " ", k))
print_dict(v, delimiter + 4)
elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
logger.info("{}{} : ".format(delimiter * " ", k))
for value in v:
print_dict(value, delimiter + 4)
else:
logger.info("{}{} : {}".format(delimiter * " ", k, v))
if k.isupper():
logger.info(placeholder)
def print_config(config):
"""
visualize configs
Arguments:
config: configs
"""
logger.advertise()
print_dict(config)
def override(dl, ks, v):
"""
Recursively replace dict of list
Args:
dl(dict or list): dict or list to be replaced
ks(list): list of keys
v(str): value to be replaced
"""
def str2num(v):
try:
return eval(v)
except Exception:
return v
assert isinstance(dl, (list, dict)), ("{} should be a list or a dict")
assert len(ks) > 0, ('lenght of keys should larger than 0')
if isinstance(dl, list):
k = str2num(ks[0])
if len(ks) == 1:
assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
dl[k] = str2num(v)
else:
override(dl[k], ks[1:], v)
else:
if len(ks) == 1:
# assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
if not ks[0] in dl:
print('A new filed ({}) detected!'.format(ks[0], dl))
dl[ks[0]] = str2num(v)
else:
override(dl[ks[0]], ks[1:], v)
def override_config(config, options=None):
"""
Recursively override the config
Args:
config(dict): dict to be replaced
options(list): list of pairs(key0.key1.idx.key2=value)
such as: [
'topk=2',
'VALID.transforms.1.ResizeImage.resize_short=300'
]
Returns:
config(dict): replaced config
"""
if options is not None:
for opt in options:
assert isinstance(opt, str), (
"option({}) should be a str".format(opt))
assert "=" in opt, (
"option({}) should contain a ="
"to distinguish between key and value".format(opt))
pair = opt.split('=')
assert len(pair) == 2, ("there can be only a = in the option")
key, value = pair
keys = key.split('.')
override(config, keys, value)
return config
def get_config(fname, overrides=None, show=False):
"""
Read config from file
"""
assert os.path.exists(fname), (
'config file({}) is not exist'.format(fname))
config = parse_config(fname)
override_config(config, overrides)
if show:
print_config(config)
return config
def parse_args():
parser = argparse.ArgumentParser("generic-image-rec train script")
parser.add_argument(
'-c',
'--config',
type=str,
default='configs/config.yaml',
help='config file path')
parser.add_argument(
'-o',
'--override',
action='append',
default=[],
help='config options to be overridden')
parser.add_argument(
'-p',
'--profiler_options',
type=str,
default=None,
help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
)
parser.add_argument(
'-t',
'--task_config',
type=str,
default="examples/task.yaml",
help='task config file path'
)
args = parser.parse_args()
return args
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import logging
import datetime
import paddle.distributed as dist
_logger = None
def init_logger(name='root', log_file=None, log_level=logging.INFO):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added. If `log_file` is specified a FileHandler will also be added.
Args:
name (str): Logger name.
log_file (str | None): The log filename. If specified, a FileHandler
will be added to the logger.
log_level (int): The logger level. Note that only the process of
rank 0 is affected, and other processes will set the level to
"Error" thus be silent most of the time.
Returns:
logging.Logger: The expected logger.
"""
global _logger
assert _logger is None, "logger should not be initialized twice or more."
_logger = logging.getLogger(name)
formatter = logging.Formatter(
'[%(asctime)s] %(name)s %(levelname)s: %(message)s',
datefmt="%Y/%m/%d %H:%M:%S")
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setFormatter(formatter)
_logger.addHandler(stream_handler)
if log_file is not None and dist.get_rank() == 0:
log_file_folder = os.path.split(log_file)[0]
os.makedirs(log_file_folder, exist_ok=True)
file_handler = logging.FileHandler(log_file, 'a')
file_handler.setFormatter(formatter)
_logger.addHandler(file_handler)
if dist.get_rank() == 0:
_logger.setLevel(log_level)
else:
_logger.setLevel(logging.ERROR)
def log_at_trainer0(log):
"""
logs will print multi-times when calling Fleet API.
Only display single log and ignore the others.
"""
def wrapper(fmt, *args):
if dist.get_rank() == 0:
log(fmt, *args)
return wrapper
@log_at_trainer0
def info(fmt, *args):
_logger.info(fmt, *args)
@log_at_trainer0
def debug(fmt, *args):
_logger.debug(fmt, *args)
@log_at_trainer0
def warning(fmt, *args):
_logger.warning(fmt, *args)
@log_at_trainer0
def error(fmt, *args):
_logger.error(fmt, *args)
def scaler(name, value, step, writer):
"""
This function will draw a scalar curve generated by the visualdl.
Usage: Install visualdl: pip3 install visualdl==2.0.0b4
and then:
visualdl --logdir ./scalar --host 0.0.0.0 --port 8830
to preview loss corve in real time.
"""
if writer is None:
return
writer.add_scalar(tag=name, step=step, value=value)
def advertise():
"""
Show the advertising message like the following:
===========================================================
== PaddleClas is powered by PaddlePaddle ! ==
===========================================================
== ==
== For more info please go to the following website. ==
== ==
== https://github.com/PaddlePaddle/PaddleClas ==
===========================================================
"""
copyright = "PaddleClas is powered by PaddlePaddle !"
ad = "For more info please go to the following website."
website = "https://github.com/PaddlePaddle/PaddleClas"
AD_LEN = 6 + len(max([copyright, ad, website], key=len))
info("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
"=" * (AD_LEN + 4),
"=={}==".format(copyright.center(AD_LEN)),
"=" * (AD_LEN + 4),
"=={}==".format(' ' * AD_LEN),
"=={}==".format(ad.center(AD_LEN)),
"=={}==".format(' ' * AD_LEN),
"=={}==".format(website.center(AD_LEN)),
"=" * (AD_LEN + 4), ))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册