Commit c356974e authored by littletomatodonkey

add inference and its yaml

Parent 97cd5540
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
import importlib
import numpy as np
import paddle
import paddle.nn.functional as F
def build_postprocess(config):
config = copy.deepcopy(config)
model_name = config.pop("name")
mod = importlib.import_module(__name__)
postprocess_func = getattr(mod, model_name)(**config)
return postprocess_func
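# Usage sketch (illustrative, not part of the original file): build_postprocess consumes the
# "PostProcess" section of the yaml as a plain dict whose "name" key selects a class defined in
# this module; the remaining keys are forwarded as constructor kwargs, e.g.
#   postprocess = build_postprocess({"name": "Topk", "topk": 5, "class_id_map_file": None})
# returns a Topk instance equivalent to Topk(topk=5, class_id_map_file=None).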
class Topk(object):
def __init__(self, topk=1, class_id_map_file=None):
assert isinstance(topk, (int, ))
self.class_id_map = self.parse_class_id_map(class_id_map_file)
self.topk = topk
def parse_class_id_map(self, class_id_map_file):
if class_id_map_file is None:
return None
if not os.path.exists(class_id_map_file):
print(
"Warning: If you want to use your own label_dict, please provide a valid path!\nOtherwise, label_names will be empty!"
)
return None
try:
class_id_map = {}
with open(class_id_map_file, "r") as fin:
lines = fin.readlines()
for line in lines:
partition = line.split("\n")[0].partition(" ")
class_id_map[int(partition[0])] = str(partition[-1])
except Exception as ex:
print(ex)
class_id_map = None
return class_id_map
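# Expected label-map format (an assumption based on the parsing above): one "<class_id> <label name>"
# pair per line, split on the first space, e.g.
#   0 tench, Tinca tinca
#   1 goldfish, Carassius auratus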
def __call__(self, x, file_names=None):
if file_names is not None:
assert x.shape[0] == len(file_names)
y = []
for idx, probs in enumerate(x):
index = probs.argsort(axis=0)[-self.topk:][::-1].astype("int32")
clas_id_list = []
score_list = []
label_name_list = []
for i in index:
clas_id_list.append(i.item())
score_list.append(probs[i].item())
if self.class_id_map is not None:
label_name_list.append(self.class_id_map[i.item()])
result = {
"class_ids": clas_id_list,
"scores": np.around(
score_list, decimals=5).tolist(),
}
if file_names is not None:
result["file_name"] = file_names[idx]
if label_name_list is not None:
result["label_names"] = label_name_list
y.append(result)
return y
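# Example result (a sketch with made-up scores): for one image with topk=2 and no class_id_map,
# __call__ returns a list such as
#   [{"class_ids": [283, 282], "scores": [0.57012, 0.12893], "file_name": "cat.jpg", "label_names": []}]
# where "file_name" is only present when file_names is passed in.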
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))
import cv2
import numpy as np
from utils import logger
from utils import config
from utils.predictor import Predictor
from utils.get_image_list import get_image_list
from preprocess import create_operators
from postprocess import build_postprocess
class ClsPredictor(object):
def __init__(self, config):
super().__init__()
self.predictor = Predictor(config["Global"])
self.preprocess_ops = create_operators(config["PreProcess"][
"transform_ops"])
self.postprocess = build_postprocess(config["PostProcess"])
def main(config):
cls_predictor = ClsPredictor(config)
image_list = get_image_list(config["Global"]["infer_imgs"])
assert config["Global"]["batch_size"] == 1
for idx, image_file in enumerate(image_list):
batch_input = []
img = cv2.imread(image_file)[:, :, ::-1]
for ops in cls_predictor.preprocess_ops:
img = ops(img)
batch_input.append(img)
output = cls_predictor.predictor.predict(np.array(batch_input))
output = cls_predictor.postprocess(output)
print(output)
return
if __name__ == "__main__":
args = config.parse_args()
config = config.get_config(args.config, overrides=args.override, show=True)
main(config)
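# Illustrative invocation (the script and yaml paths are assumptions, adjust to your layout):
#   python python/predict_cls.py -c configs/inference_cls.yaml -o Global.infer_imgs=./images/
# The yaml is expected to provide a Global section (predictor settings, infer_imgs, batch_size=1),
# a PreProcess.transform_ops list and a PostProcess section.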
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import six
import math
import random
import cv2
import numpy as np
import importlib
def create_operators(params):
"""
create operators based on the config
Args:
params(list): a dict list, used to create some operators
"""
assert isinstance(params, list), ('operator config should be a list')
mod = importlib.import_module(__name__)
ops = []
for operator in params:
assert isinstance(operator,
dict) and len(operator) == 1, "yaml format error"
op_name = list(operator)[0]
param = {} if operator[op_name] is None else operator[op_name]
op = getattr(mod, op_name)(**param)
ops.append(op)
return ops
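# A minimal sketch of the expected "transform_ops" config (assumed values, mirroring the yaml):
#   ops = create_operators([
#       {"ResizeImage": {"resize_short": 256}},
#       {"CropImage": {"size": 224}},
#       {"NormalizeImage": {"scale": "1.0/255.0", "order": "hwc"}},
#       {"ToCHWImage": None}])
# Each item is a single-key dict; the key names an operator class in this module and the value
# (or None) holds its constructor kwargs.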
class OperatorParamError(ValueError):
""" OperatorParamError
"""
pass
class DecodeImage(object):
""" decode image """
def __init__(self, to_rgb=True, to_np=False, channel_first=False):
self.to_rgb = to_rgb
self.to_np = to_np # to numpy
self.channel_first = channel_first # only enabled when to_np is True
def __call__(self, img):
if six.PY2:
assert type(img) is str and len(
img) > 0, "invalid input 'img' in DecodeImage"
else:
assert type(img) is bytes and len(
img) > 0, "invalid input 'img' in DecodeImage"
data = np.frombuffer(img, dtype='uint8')
img = cv2.imdecode(data, 1)
if self.to_rgb:
assert img.shape[2] == 3, 'invalid shape of image[%s]' % (
img.shape)
img = img[:, :, ::-1]
if self.channel_first:
img = img.transpose((2, 0, 1))
return img
class ResizeImage(object):
""" resize image """
def __init__(self, size=None, resize_short=None, interpolation=-1):
self.interpolation = interpolation if interpolation >= 0 else None
if resize_short is not None and resize_short > 0:
self.resize_short = resize_short
self.w = None
self.h = None
elif size is not None:
self.resize_short = None
self.w = size if type(size) is int else size[0]
self.h = size if type(size) is int else size[1]
else:
raise OperatorParamError("invalid params for ReisizeImage for '\
'both 'size' and 'resize_short' are None")
def __call__(self, img):
img_h, img_w = img.shape[:2]
if self.resize_short is not None:
percent = float(self.resize_short) / min(img_w, img_h)
w = int(round(img_w * percent))
h = int(round(img_h * percent))
else:
w = self.w
h = self.h
if self.interpolation is None:
return cv2.resize(img, (w, h))
else:
return cv2.resize(img, (w, h), interpolation=self.interpolation)
class CropImage(object):
""" crop image """
def __init__(self, size):
if type(size) is int:
self.size = (size, size)
else:
self.size = size # (h, w)
def __call__(self, img):
w, h = self.size
img_h, img_w = img.shape[:2]
w_start = (img_w - w) // 2
h_start = (img_h - h) // 2
w_end = w_start + w
h_end = h_start + h
return img[h_start:h_end, w_start:w_end, :]
class RandCropImage(object):
""" random crop image """
def __init__(self, size, scale=None, ratio=None, interpolation=-1):
self.interpolation = interpolation if interpolation >= 0 else None
if type(size) is int:
self.size = (size, size) # (h, w)
else:
self.size = size
self.scale = [0.08, 1.0] if scale is None else scale
self.ratio = [3. / 4., 4. / 3.] if ratio is None else ratio
def __call__(self, img):
size = self.size
scale = self.scale
ratio = self.ratio
aspect_ratio = math.sqrt(random.uniform(*ratio))
w = 1. * aspect_ratio
h = 1. / aspect_ratio
img_h, img_w = img.shape[:2]
bound = min((float(img_w) / img_h) / (w**2),
(float(img_h) / img_w) / (h**2))
scale_max = min(scale[1], bound)
scale_min = min(scale[0], bound)
target_area = img_w * img_h * random.uniform(scale_min, scale_max)
target_size = math.sqrt(target_area)
w = int(target_size * w)
h = int(target_size * h)
i = random.randint(0, img_w - w)
j = random.randint(0, img_h - h)
img = img[j:j + h, i:i + w, :]
if self.interpolation is None:
return cv2.resize(img, size)
else:
return cv2.resize(img, size, interpolation=self.interpolation)
class RandFlipImage(object):
""" random flip image
flip_code:
1: Flipped Horizontally
0: Flipped Vertically
-1: Flipped Horizontally & Vertically
"""
def __init__(self, flip_code=1):
assert flip_code in [-1, 0, 1
], "flip_code should be a value in [-1, 0, 1]"
self.flip_code = flip_code
def __call__(self, img):
if random.randint(0, 1) == 1:
return cv2.flip(img, self.flip_code)
else:
return img
class AutoAugment(object):
def __init__(self):
# ImageNetPolicy is assumed to be importable from the autoaugment module shipped with this code.
self.policy = ImageNetPolicy()
def __call__(self, img):
from PIL import Image
img = np.ascontiguousarray(img)
img = Image.fromarray(img)
img = self.policy(img)
img = np.asarray(img)
return img
class NormalizeImage(object):
""" normalize image such as substract mean, divide std
"""
def __init__(self,
scale=None,
mean=None,
std=None,
order='chw',
output_fp16=False,
channel_num=3):
if isinstance(scale, str):
scale = eval(scale)
assert channel_num in [
3, 4
], "channel number of input image should be set to 3 or 4."
self.channel_num = channel_num
self.output_dtype = 'float16' if output_fp16 else 'float32'
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
self.order = order
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
shape = (3, 1, 1) if self.order == 'chw' else (1, 1, 3)
self.mean = np.array(mean).reshape(shape).astype('float32')
self.std = np.array(std).reshape(shape).astype('float32')
def __call__(self, img):
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
assert isinstance(img,
np.ndarray), "invalid input 'img' in NormalizeImage"
img = (img.astype('float32') * self.scale - self.mean) / self.std
if self.channel_num == 4:
img_h = img.shape[1] if self.order == 'chw' else img.shape[0]
img_w = img.shape[2] if self.order == 'chw' else img.shape[1]
pad_zeros = np.zeros(
(1, img_h, img_w)) if self.order == 'chw' else np.zeros(
(img_h, img_w, 1))
img = (np.concatenate(
(img, pad_zeros), axis=0)
if self.order == 'chw' else np.concatenate(
(img, pad_zeros), axis=2))
return img.astype(self.output_dtype)
class ToCHWImage(object):
""" convert hwc image to chw image
"""
def __init__(self):
pass
def __call__(self, img):
from PIL import Image
if isinstance(img, Image.Image):
img = np.array(img)
return img.transpose((2, 0, 1))
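# Applying the operators (sketch): with an image already decoded to an RGB numpy array
# (as done in the prediction script above), the pipeline is simply
#   for op in ops:
#       img = op(img)
# which, for the sketch above, yields a float32 CHW array ready to be batched.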
from . import logger
from . import config
from . import get_image_list
from . import predictor
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
import argparse
import yaml
from utils import logger
__all__ = ['get_config']
class AttrDict(dict):
def __getattr__(self, key):
return self[key]
def __setattr__(self, key, value):
if key in self.__dict__:
self.__dict__[key] = value
else:
self[key] = value
def __deepcopy__(self, content):
return copy.deepcopy(dict(self))
def create_attr_dict(yaml_config):
from ast import literal_eval
for key, value in yaml_config.items():
if type(value) is dict:
yaml_config[key] = value = AttrDict(value)
if isinstance(value, str):
try:
value = literal_eval(value)
except BaseException:
pass
if isinstance(value, AttrDict):
create_attr_dict(yaml_config[key])
else:
yaml_config[key] = value
def parse_config(cfg_file):
"""Load a config file into AttrDict"""
with open(cfg_file, 'r') as fopen:
yaml_config = AttrDict(yaml.load(fopen, Loader=yaml.SafeLoader))
create_attr_dict(yaml_config)
return yaml_config
def print_dict(d, delimiter=0):
"""
Recursively visualize a dict and
indent according to the relationship of keys.
"""
placeholder = "-" * 60
for k, v in sorted(d.items()):
if isinstance(v, dict):
logger.info("{}{} : ".format(delimiter * " ",
logger.coloring(k, "HEADER")))
print_dict(v, delimiter + 4)
elif isinstance(v, list) and len(v) >= 1 and isinstance(v[0], dict):
logger.info("{}{} : ".format(delimiter * " ",
logger.coloring(str(k), "HEADER")))
for value in v:
print_dict(value, delimiter + 4)
else:
logger.info("{}{} : {}".format(delimiter * " ",
logger.coloring(k, "HEADER"),
logger.coloring(v, "OKGREEN")))
if k.isupper():
logger.info(placeholder)
def print_config(config):
"""
visualize configs
Arguments:
config: configs
"""
logger.advertise()
print_dict(config)
def override(dl, ks, v):
"""
Recursively replace dict of list
Args:
dl(dict or list): dict or list to be replaced
ks(list): list of keys
v(str): value to be replaced
"""
def str2num(v):
try:
return eval(v)
except Exception:
return v
assert isinstance(dl, (list, dict)), ("{} should be a list or a dict".format(dl))
assert len(ks) > 0, ('length of keys should be larger than 0')
if isinstance(dl, list):
k = str2num(ks[0])
if len(ks) == 1:
assert k < len(dl), ('index({}) out of range({})'.format(k, dl))
dl[k] = str2num(v)
else:
override(dl[k], ks[1:], v)
else:
if len(ks) == 1:
# assert ks[0] in dl, ('{} is not exist in {}'.format(ks[0], dl))
if not ks[0] in dl:
logger.warning('A new field ({}) is detected!'.format(ks[0]))
dl[ks[0]] = str2num(v)
else:
override(dl[ks[0]], ks[1:], v)
def override_config(config, options=None):
"""
Recursively override the config
Args:
config(dict): dict to be replaced
options(list): list of pairs(key0.key1.idx.key2=value)
such as: [
'topk=2',
'VALID.transforms.1.ResizeImage.resize_short=300'
]
Returns:
config(dict): replaced config
"""
if options is not None:
for opt in options:
assert isinstance(opt, str), (
"option({}) should be a str".format(opt))
assert "=" in opt, (
"option({}) should contain a ="
"to distinguish between key and value".format(opt))
pair = opt.split('=')
assert len(pair) == 2, ("the option must contain exactly one =")
key, value = pair
keys = key.split('.')
override(config, keys, value)
return config
def get_config(fname, overrides=None, show=True):
"""
Read config from file
"""
assert os.path.exists(fname), (
'config file ({}) does not exist'.format(fname))
config = parse_config(fname)
override_config(config, overrides)
if show:
print_config(config)
# check_config(config)
return config
def parse_args():
parser = argparse.ArgumentParser("generic-image-rec train script")
parser.add_argument(
'-c',
'--config',
type=str,
default='configs/config.yaml',
help='config file path')
parser.add_argument(
'-o',
'--override',
action='append',
default=[],
help='config options to be overridden')
args = parser.parse_args()
return args
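# Override syntax sketch (the yaml path is an assumption): each -o entry is handed to
# override_config, so nested keys are addressed with dots and list elements with integer indices:
#   python predict_cls.py -c configs/inference_cls.yaml -o Global.use_gpu=False \
#       -o PreProcess.transform_ops.0.ResizeImage.resize_short=256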
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import base64
import numpy as np
def get_image_list(img_file):
imgs_lists = []
if img_file is None or not os.path.exists(img_file):
raise Exception("not found any img file in {}".format(img_file))
img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp']
if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end:
imgs_lists.append(img_file)
elif os.path.isdir(img_file):
for single_file in os.listdir(img_file):
if single_file.split('.')[-1] in img_end:
imgs_lists.append(os.path.join(img_file, single_file))
if len(imgs_lists) == 0:
raise Exception("not found any img file in {}".format(img_file))
imgs_lists = sorted(imgs_lists)
return imgs_lists
def get_image_list_from_label_file(image_path, label_file_path):
imgs_lists = []
gt_labels = []
with open(label_file_path, "r") as fin:
lines = fin.readlines()
for line in lines:
image_name, label = line.strip("\n").split()
label = int(label)
imgs_lists.append(os.path.join(image_path, image_name))
gt_labels.append(int(label))
return imgs_lists, gt_labels
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import datetime
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S")
def time_zone(sec, fmt):
real_time = datetime.datetime.now()
return real_time.timetuple()
logging.Formatter.converter = time_zone
_logger = logging.getLogger(__name__)
Color = {
'RED': '\033[31m',
'HEADER': '\033[35m', # deep purple
'PURPLE': '\033[95m', # purple
'OKBLUE': '\033[94m',
'OKGREEN': '\033[92m',
'WARNING': '\033[93m',
'FAIL': '\033[91m',
'ENDC': '\033[0m'
}
def coloring(message, color="OKGREEN"):
assert color in Color.keys()
if os.environ.get('PADDLECLAS_COLORING', False):
return Color[color] + str(message) + Color["ENDC"]
else:
return message
def anti_fleet(log):
"""
Logs will be printed multiple times when calling the Fleet API.
Only display a single log and ignore the others.
"""
def wrapper(fmt, *args):
if int(os.getenv("PADDLE_TRAINER_ID", 0)) == 0:
log(fmt, *args)
return wrapper
@anti_fleet
def info(fmt, *args):
_logger.info(fmt, *args)
@anti_fleet
def warning(fmt, *args):
_logger.warning(coloring(fmt, "RED"), *args)
@anti_fleet
def error(fmt, *args):
_logger.error(coloring(fmt, "FAIL"), *args)
def scaler(name, value, step, writer):
"""
This function will draw a scalar curve generated by the visualdl.
Usage: Install visualdl: pip3 install visualdl==2.0.0b4
and then:
visualdl --logdir ./scalar --host 0.0.0.0 --port 8830
to preview the loss curve in real time.
"""
writer.add_scalar(tag=name, step=step, value=value)
def advertise():
"""
Show the advertising message like the following:
===========================================================
== PaddleClas is powered by PaddlePaddle ! ==
===========================================================
== ==
== For more info please go to the following website. ==
== ==
== https://github.com/PaddlePaddle/PaddleClas ==
===========================================================
"""
copyright = "PaddleClas is powered by PaddlePaddle !"
ad = "For more info please go to the following website."
website = "https://github.com/PaddlePaddle/PaddleClas"
AD_LEN = 6 + len(max([copyright, ad, website], key=len))
info(
coloring("\n{0}\n{1}\n{2}\n{3}\n{4}\n{5}\n{6}\n{7}\n".format(
"=" * (AD_LEN + 4),
"=={}==".format(copyright.center(AD_LEN)),
"=" * (AD_LEN + 4),
"=={}==".format(' ' * AD_LEN),
"=={}==".format(ad.center(AD_LEN)),
"=={}==".format(' ' * AD_LEN),
"=={}==".format(website.center(AD_LEN)),
"=" * (AD_LEN + 4), ), "RED"))
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import base64
import shutil
import cv2
import numpy as np
from paddle.inference import Config
from paddle.inference import create_predictor
class Predictor(object):
def __init__(self, args):
# half-precision prediction only works when using TensorRT
if args.use_fp16 is True:
assert args.use_tensorrt is True
self.args = args
self.paddle_predictor = self.create_paddle_predictor(args)
input_names = self.paddle_predictor.get_input_names()
self.input_tensor = self.paddle_predictor.get_input_handle(input_names[
0])
output_names = self.paddle_predictor.get_output_names()
self.output_tensor = self.paddle_predictor.get_output_handle(
output_names[0])
def predict(self, batch_input):
self.input_tensor.copy_from_cpu(batch_input)
self.paddle_predictor.run()
batch_output = self.output_tensor.copy_to_cpu()
return batch_output
def create_paddle_predictor(self, args):
params_file = os.path.join(args.inference_model_dir,
"inference.pdiparams")
model_file = os.path.join(args.inference_model_dir,
"inference.pdmodel")
config = Config(model_file, params_file)
if args.use_gpu:
config.enable_use_gpu(args.gpu_mem, 0)
else:
config.disable_gpu()
if args.enable_mkldnn:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
config.set_cpu_math_library_num_threads(args.cpu_num_threads)
if args.enable_profile:
config.enable_profile()
config.disable_glog_info()
config.switch_ir_optim(args.ir_optim) # default true
if args.use_tensorrt:
config.enable_tensorrt_engine(
precision_mode=Config.Precision.Half
if args.use_fp16 else Config.Precision.Float32,
max_batch_size=args.batch_size)
config.enable_memory_optim()
# use zero copy
config.switch_use_feed_fetch_ops(False)
predictor = create_predictor(config)
return predictor
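# The "Global" section consumed by this Predictor is assumed (a sketch, values are examples) to
# carry the attributes read above:
#   Global:
#     inference_model_dir: ./inference/
#     batch_size: 1
#     use_gpu: True
#     gpu_mem: 8000
#     use_fp16: False
#     use_tensorrt: False
#     ir_optim: True
#     enable_mkldnn: False
#     cpu_num_threads: 10
#     enable_profile: False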
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import cv2
import time
import sys
sys.path.insert(0, ".")
from ppcls.utils import logger
from tools.infer.utils import parse_args, create_paddle_predictor, preprocess, postprocess
from tools.infer.utils import get_image_list, get_image_list_from_label_file, calc_topk_acc
class Predictor(object):
def __init__(self, args):
# half-precision prediction only works when using TensorRT
if args.use_fp16 is True:
assert args.use_tensorrt is True
self.args = args
self.paddle_predictor = create_paddle_predictor(args)
input_names = self.paddle_predictor.get_input_names()
self.input_tensor = self.paddle_predictor.get_input_handle(input_names[
0])
output_names = self.paddle_predictor.get_output_names()
self.output_tensor = self.paddle_predictor.get_output_handle(
output_names[0])
def predict(self, batch_input):
self.input_tensor.copy_from_cpu(batch_input)
self.paddle_predictor.run()
batch_output = self.output_tensor.copy_to_cpu()
return batch_output
def normal_predict(self):
if self.args.enable_calc_topk:
assert self.args.gt_label_path is not None and os.path.exists(self.args.gt_label_path), \
"gt_label_path shoule not be None and must exist, please check its path."
image_list, gt_labels = get_image_list_from_label_file(
self.args.image_file, self.args.gt_label_path)
predicts_map = {
"prediction": [],
"gt_label": [],
}
else:
image_list = get_image_list(self.args.image_file)
gt_labels = None
batch_input_list = []
img_name_list = []
cnt = 0
for idx, img_path in enumerate(image_list):
img = cv2.imread(img_path)
if img is None:
logger.warning(
"Image file failed to read and has been skipped. The path: {}".
format(img_path))
continue
else:
img = img[:, :, ::-1]
img = preprocess(img, args)
batch_input_list.append(img)
img_name = img_path.split("/")[-1]
img_name_list.append(img_name)
cnt += 1
if self.args.enable_calc_topk:
predicts_map["gt_label"].append(gt_labels[idx])
if cnt % args.batch_size == 0 or (idx + 1) == len(image_list):
batch_outputs = self.predict(np.array(batch_input_list))
batch_result_list = postprocess(batch_outputs, self.args.top_k)
for number, result_dict in enumerate(batch_result_list):
filename = img_name_list[number]
clas_ids = result_dict["clas_ids"]
scores_str = "[{}]".format(", ".join("{:.2f}".format(
r) for r in result_dict["scores"]))
logger.info(
"File:{}, Top-{} result: class id(s): {}, score(s): {}".
format(filename, self.args.top_k, clas_ids,
scores_str))
if self.args.enable_calc_topk:
predicts_map["prediction"].append(clas_ids)
batch_input_list = []
img_name_list = []
if self.args.enable_calc_topk:
topk_acc = calc_topk_acc(predicts_map)
for idx, acc in enumerate(topk_acc):
logger.info("Top-{} acc: {:.5f}".format(idx + 1, acc))
def benchmark_predict(self):
test_num = 500
test_time = 0.0
for i in range(0, test_num + 10):
inputs = np.random.rand(args.batch_size, 3, 224,
224).astype(np.float32)
start_time = time.time()
batch_output = self.predict(inputs).flatten()
if i >= 10:
test_time += time.time() - start_time
time.sleep(0.01) # sleep for T4 GPU
fp_message = "FP16" if args.use_fp16 else "FP32"
trt_msg = "using tensorrt" if args.use_tensorrt else "not using tensorrt"
print("{0}\t{1}\t{2}\tbatch size: {3}\ttime(ms): {4}".format(
args.model, trt_msg, fp_message, args.batch_size, 1000 * test_time
/ test_num))
if __name__ == "__main__":
args = parse_args()
assert os.path.exists(
args.model_file), "The path of 'model_file' does not exist: {}".format(
args.model_file)
assert os.path.exists(
args.params_file
), "The path of 'params_file' does not exist: {}".format(args.params_file)
predictor = Predictor(args)
if not args.enable_benchmark:
predictor.normal_predict()
else:
assert args.model is not None
predictor.benchmark_predict()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import base64
import shutil
import cv2
import numpy as np
from paddle.inference import Config
from paddle.inference import create_predictor
def parse_args():
def str2bool(v):
return v.lower() in ("true", "t", "1")
# general params
parser = argparse.ArgumentParser()
parser.add_argument("-i", "--image_file", type=str)
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--multilabel", type=str2bool, default=False)
# params for preprocess
parser.add_argument("--resize_short", type=int, default=256)
parser.add_argument("--resize", type=int, default=224)
parser.add_argument("--normalize", type=str2bool, default=True)
# params for predict
parser.add_argument("--model_file", type=str)
parser.add_argument("--params_file", type=str)
parser.add_argument("-b", "--batch_size", type=int, default=1)
parser.add_argument("--use_fp16", type=str2bool, default=False)
parser.add_argument("--ir_optim", type=str2bool, default=True)
parser.add_argument("--use_tensorrt", type=str2bool, default=False)
parser.add_argument("--gpu_mem", type=int, default=8000)
parser.add_argument("--enable_profile", type=str2bool, default=False)
parser.add_argument("--enable_benchmark", type=str2bool, default=False)
parser.add_argument("--top_k", type=int, default=1)
parser.add_argument("--enable_mkldnn", type=str2bool, default=False)
parser.add_argument("--cpu_num_threads", type=int, default=10)
parser.add_argument("--hubserving", type=str2bool, default=False)
# params for infer
parser.add_argument("--model", type=str)
parser.add_argument("--pretrained_model", type=str)
parser.add_argument("--class_num", type=int, default=1000)
parser.add_argument(
"--load_static_weights",
type=str2bool,
default=False,
help='Whether to load the pretrained weights saved in static mode')
# parameters for pre-label the images
parser.add_argument(
"--pre_label_image",
type=str2bool,
default=False,
help="Whether to pre-label the images using the loaded weights")
parser.add_argument("--pre_label_out_idr", type=str, default=None)
# parameters for test hubserving
parser.add_argument("--server_url", type=str)
# when enable_calc_topk is set to True, top-k accuracy will be calculated
parser.add_argument("--enable_calc_topk", type=str2bool, default=False)
# ground-truth label path
# data format for each line: $image_name $class_id
parser.add_argument("--gt_label_path", type=str, default=None)
return parser.parse_args()
def create_paddle_predictor(args):
config = Config(args.model_file, args.params_file)
if args.use_gpu:
config.enable_use_gpu(args.gpu_mem, 0)
else:
config.disable_gpu()
if args.enable_mkldnn:
# cache 10 different shapes for mkldnn to avoid memory leak
config.set_mkldnn_cache_capacity(10)
config.enable_mkldnn()
config.set_cpu_math_library_num_threads(args.cpu_num_threads)
if args.enable_profile:
config.enable_profile()
config.disable_glog_info()
config.switch_ir_optim(args.ir_optim) # default true
if args.use_tensorrt:
config.enable_tensorrt_engine(
precision_mode=Config.Precision.Half
if args.use_fp16 else Config.Precision.Float32,
max_batch_size=args.batch_size)
config.enable_memory_optim()
# use zero copy
config.switch_use_feed_fetch_ops(False)
predictor = create_predictor(config)
return predictor
def preprocess(img, args):
resize_op = ResizeImage(resize_short=args.resize_short)
img = resize_op(img)
crop_op = CropImage(size=(args.resize, args.resize))
img = crop_op(img)
if args.normalize:
img_mean = [0.485, 0.456, 0.406]
img_std = [0.229, 0.224, 0.225]
img_scale = 1.0 / 255.0
normalize_op = NormalizeImage(
scale=img_scale, mean=img_mean, std=img_std)
img = normalize_op(img)
tensor_op = ToTensor()
img = tensor_op(img)
return img
def postprocess(batch_outputs, topk=5, multilabel=False):
batch_results = []
for probs in batch_outputs:
if multilabel:
index = np.where(probs >= 0.5)[0].astype('int32')
else:
index = probs.argsort(axis=0)[-topk:][::-1].astype("int32")
clas_id_list = []
score_list = []
for i in index:
clas_id_list.append(i.item())
score_list.append(probs[i].item())
batch_results.append({"clas_ids": clas_id_list, "scores": score_list})
return batch_results
def get_image_list(img_file):
imgs_lists = []
if img_file is None or not os.path.exists(img_file):
raise Exception("not found any img file in {}".format(img_file))
img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp']
if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end:
imgs_lists.append(img_file)
elif os.path.isdir(img_file):
for single_file in os.listdir(img_file):
if single_file.split('.')[-1] in img_end:
imgs_lists.append(os.path.join(img_file, single_file))
if len(imgs_lists) == 0:
raise Exception("not found any img file in {}".format(img_file))
return imgs_lists
def get_image_list_from_label_file(image_path, label_file_path):
imgs_lists = []
gt_labels = []
with open(label_file_path, "r") as fin:
lines = fin.readlines()
for line in lines:
image_name, label = line.strip("\n").split()
label = int(label)
imgs_lists.append(os.path.join(image_path, image_name))
gt_labels.append(int(label))
return imgs_lists, gt_labels
def calc_topk_acc(info_map):
'''
calc_topk_acc
input:
info_map(dict): keys are prediction and gt_label
output:
topk_acc(list): top-k accuracy list
'''
gt_label = np.array(info_map["gt_label"])
prediction = np.array(info_map["prediction"])
gt_label = np.reshape(gt_label, (-1, 1)).repeat(
prediction.shape[1], axis=1)
correct = np.equal(prediction, gt_label)
topk_acc = []
for idx in range(prediction.shape[1]):
if idx > 0:
correct[:, idx] = np.logical_or(correct[:, idx],
correct[:, idx - 1])
topk_acc.append(1.0 * np.sum(correct[:, idx]) / correct.shape[0])
return topk_acc
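# Worked example (illustrative): with
#   info_map = {"prediction": [[3, 1], [2, 0]], "gt_label": [3, 0]}
# the first sample is correct at top-1 and the second only at top-2, so
#   calc_topk_acc(info_map) == [0.5, 1.0]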
def save_prelabel_results(class_id, input_file_path, output_dir):
output_dir = os.path.join(output_dir, str(class_id))
if not os.path.isdir(output_dir):
os.makedirs(output_dir)
shutil.copy(input_file_path, output_dir)
class ResizeImage(object):
def __init__(self, resize_short=None):
self.resize_short = resize_short
def __call__(self, img):
img_h, img_w = img.shape[:2]
percent = float(self.resize_short) / min(img_w, img_h)
w = int(round(img_w * percent))
h = int(round(img_h * percent))
return cv2.resize(img, (w, h))
class CropImage(object):
def __init__(self, size):
if type(size) is int:
self.size = (size, size)
else:
self.size = size
def __call__(self, img):
w, h = self.size
img_h, img_w = img.shape[:2]
w_start = (img_w - w) // 2
h_start = (img_h - h) // 2
w_end = w_start + w
h_end = h_start + h
return img[h_start:h_end, w_start:w_end, :]
class NormalizeImage(object):
def __init__(self, scale=None, mean=None, std=None):
self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
mean = mean if mean is not None else [0.485, 0.456, 0.406]
std = std if std is not None else [0.229, 0.224, 0.225]
shape = (1, 1, 3)
self.mean = np.array(mean).reshape(shape).astype('float32')
self.std = np.array(std).reshape(shape).astype('float32')
def __call__(self, img):
return (img.astype('float32') * self.scale - self.mean) / self.std
class ToTensor(object):
def __init__(self):
pass
def __call__(self, img):
img = img.transpose((2, 0, 1))
return img
def b64_to_np(b64str, revert_params):
shape = revert_params["shape"]
dtype = revert_params["dtype"]
dtype = getattr(np, dtype) if isinstance(dtype, str) else dtype
data = base64.b64decode(b64str.encode('utf8'))
data = np.frombuffer(data, dtype=dtype).reshape(shape)
return data
def np_to_b64(images):
img_str = base64.b64encode(images).decode('utf8')
return img_str, images.shape
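# Round-trip sketch (assumed usage, e.g. for hubserving requests):
#   img = np.zeros((224, 224, 3), dtype=np.uint8)
#   b64str, shape = np_to_b64(img)
#   restored = b64_to_np(b64str, {"shape": shape, "dtype": "uint8"})
# np.array_equal(img, restored) is True; "dtype" may also be a numpy dtype object.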