From 0a0e34298b8cc4882458659d7225f6827fa0d235 Mon Sep 17 00:00:00 2001 From: littletomatodonkey <2120160898@bit.edu.cn> Date: Tue, 1 Jun 2021 14:37:47 +0800 Subject: [PATCH] add support for infer and export (#760) * add support for infer * add export model support * fix yaml * fix post process name * fix topk name --- ppcls/configs/ImageNet/ResNet/ResNet50.yaml | 25 +++++- ppcls/data/postprocess/__init__.py | 27 +++++++ ppcls/data/postprocess/topk.py | 75 ++++++++++++++++++ ppcls/data/utils/__init__.py | 13 ++++ ppcls/data/utils/get_image_list.py | 49 ++++++++++++ ppcls/engine/trainer.py | 37 +++++++++ tools/export_model.py | 85 ++++++++++----------- tools/infer.py | 31 ++++++++ 8 files changed, 295 insertions(+), 47 deletions(-) create mode 100644 ppcls/data/postprocess/__init__.py create mode 100644 ppcls/data/postprocess/topk.py create mode 100644 ppcls/data/utils/__init__.py create mode 100644 ppcls/data/utils/get_image_list.py create mode 100644 tools/infer.py diff --git a/ppcls/configs/ImageNet/ResNet/ResNet50.yaml b/ppcls/configs/ImageNet/ResNet/ResNet50.yaml index 2e9cfaf9..a1a42f69 100644 --- a/ppcls/configs/ImageNet/ResNet/ResNet50.yaml +++ b/ppcls/configs/ImageNet/ResNet/ResNet50.yaml @@ -11,8 +11,9 @@ Global: epochs: 120 print_batch_step: 10 use_visualdl: False + # used for static mode and model export image_shape: [3, 224, 224] - infer_imgs: + save_inference_dir: "./inference" # model architecture Arch: @@ -91,6 +92,28 @@ DataLoader: order: '' - ToCHWImage: +Infer: + infer_imgs: "docs/images/whl/demo.jpg" + batch_size: 10 + transforms: + - DecodeImage: + to_rgb: True + channel_first: False + - ResizeImage: + resize_short: 256 + - CropImage: + size: 224 + - NormalizeImage: + scale: 1.0/255.0 + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + order: '' + - ToCHWImage: + PostProcess: + name: Topk + topk: 5 + class_id_map_file: "ppcls/utils/imagenet1k_label_list.txt" + Metric: Train: - Topk: diff --git a/ppcls/data/postprocess/__init__.py b/ppcls/data/postprocess/__init__.py new file mode 100644 index 00000000..8d78ec81 --- /dev/null +++ b/ppcls/data/postprocess/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import importlib + +from . import topk_process + +from .topk_process import Topk + + +def build_postprocess(config): + config = copy.deepcopy(config) + model_name = config.pop("name") + mod = importlib.import_module(__name__) + postprocess_func = getattr(mod, model_name)(**config) + return postprocess_func diff --git a/ppcls/data/postprocess/topk.py b/ppcls/data/postprocess/topk.py new file mode 100644 index 00000000..2410e329 --- /dev/null +++ b/ppcls/data/postprocess/topk.py @@ -0,0 +1,75 @@ +# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import numpy as np +import paddle +import paddle.nn.functional as F + + +class Topk(object): + def __init__(self, topk=1, class_id_map_file=None): + assert isinstance(topk, (int, )) + self.class_id_map = self.parse_class_id_map(class_id_map_file) + self.topk = topk + + def parse_class_id_map(self, class_id_map_file): + if class_id_map_file is None: + return None + if not os.path.exists(class_id_map_file): + print( + "Warning: If want to use your own label_dict, please input legal path!\nOtherwise label_names will be empty!" + ) + return None + + try: + class_id_map = {} + with open(class_id_map_file, "r") as fin: + lines = fin.readlines() + for line in lines: + partition = line.split("\n")[0].partition(" ") + class_id_map[int(partition[0])] = str(partition[-1]) + except Exception as ex: + print(ex) + class_id_map = None + return class_id_map + + def __call__(self, x, file_names=None): + assert isinstance(x, paddle.Tensor) + if file_names is not None: + assert x.shape[0] == len(file_names) + x = F.softmax(x, axis=-1) + x = x.numpy() + y = [] + for idx, probs in enumerate(x): + index = probs.argsort(axis=0)[-self.topk:][::-1].astype("int32") + clas_id_list = [] + score_list = [] + label_name_list = [] + for i in index: + clas_id_list.append(i.item()) + score_list.append(probs[i].item()) + if self.class_id_map is not None: + label_name_list.append(self.class_id_map[i.item()]) + result = { + "class_ids": clas_id_list, + "scores": np.around( + score_list, decimals=5).tolist(), + } + if file_names is not None: + result["file_name"] = file_names[idx] + if label_name_list is not None: + result["label_names"] = label_name_list + y.append(result) + return y diff --git a/ppcls/data/utils/__init__.py b/ppcls/data/utils/__init__.py new file mode 100644 index 00000000..61d5aa21 --- /dev/null +++ b/ppcls/data/utils/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/ppcls/data/utils/get_image_list.py b/ppcls/data/utils/get_image_list.py new file mode 100644 index 00000000..6f10935a --- /dev/null +++ b/ppcls/data/utils/get_image_list.py @@ -0,0 +1,49 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import argparse +import base64 +import numpy as np + + +def get_image_list(img_file): + imgs_lists = [] + if img_file is None or not os.path.exists(img_file): + raise Exception("not found any img file in {}".format(img_file)) + + img_end = ['jpg', 'png', 'jpeg', 'JPEG', 'JPG', 'bmp'] + if os.path.isfile(img_file) and img_file.split('.')[-1] in img_end: + imgs_lists.append(img_file) + elif os.path.isdir(img_file): + for single_file in os.listdir(img_file): + if single_file.split('.')[-1] in img_end: + imgs_lists.append(os.path.join(img_file, single_file)) + if len(imgs_lists) == 0: + raise Exception("not found any img file in {}".format(img_file)) + imgs_lists = sorted(imgs_lists) + return imgs_lists + + +def get_image_list_from_label_file(image_path, label_file_path): + imgs_lists = [] + gt_labels = [] + with open(label_file_path, "r") as fin: + lines = fin.readlines() + for line in lines: + image_name, label = line.strip("\n").split() + label = int(label) + imgs_lists.append(os.path.join(image_path, image_name)) + gt_labels.append(int(label)) + return imgs_lists, gt_labels diff --git a/ppcls/engine/trainer.py b/ppcls/engine/trainer.py index de219082..458ddb35 100644 --- a/ppcls/engine/trainer.py +++ b/ppcls/engine/trainer.py @@ -37,6 +37,10 @@ from ppcls.utils.save_load import load_dygraph_pretrain from ppcls.utils.save_load import init_model from ppcls.utils import save_load +from ppcls.data.utils.get_image_list import get_image_list +from ppcls.data.postprocess import build_postprocess +from ppcls.data.reader import create_operators + class Trainer(object): def __init__(self, config, mode="train"): @@ -277,3 +281,36 @@ class Trainer(object): return -1 # return 1st metric in the dict return output_info[metric_key].avg + + @paddle.no_grad() + def infer(self, ): + total_trainer = paddle.distributed.get_world_size() + local_rank = paddle.distributed.get_rank() + image_list = get_image_list(self.config["Infer"]["infer_imgs"]) + # data split + image_list = image_list[local_rank::total_trainer] + + preprocess_func = create_operators(self.config["Infer"]["transforms"]) + postprocess_func = build_postprocess(self.config["Infer"][ + "PostProcess"]) + + batch_size = self.config["Infer"]["batch_size"] + + self.model.eval() + + batch_data = [] + image_file_list = [] + for idx, image_file in enumerate(image_list): + with open(image_file, 'rb') as f: + x = f.read() + for process in preprocess_func: + x = process(x) + batch_data.append(x) + image_file_list.append(image_file) + if len(batch_data) >= batch_size or idx == len(image_list) - 1: + batch_tensor = paddle.to_tensor(batch_data) + out = self.model(batch_tensor) + result = postprocess_func(out, image_file_list) + print(result) + batch_data.clear() + image_file_list.clear() diff --git a/tools/export_model.py b/tools/export_model.py index c3a06fac..86e84eaa 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,40 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. -import argparse +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function import os import sys __dir__ = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(__dir__) -sys.path.append(os.path.abspath(os.path.join(__dir__, '..'))) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../'))) -from ppcls.arch import backbone -from ppcls.utils.save_load import load_dygraph_pretrain import paddle -import paddle.nn.functional as F -from paddle.jit import to_static - - -def parse_args(): - def str2bool(v): - return v.lower() in ("true", "t", "1") +import paddle.nn as nn - parser = argparse.ArgumentParser() - parser.add_argument("-m", "--model", type=str) - parser.add_argument("-p", "--pretrained_model", type=str) - parser.add_argument("-o", "--output_path", type=str, default="./inference") - parser.add_argument("--class_dim", type=int, default=1000) - parser.add_argument("--load_static_weights", type=str2bool, default=False) - parser.add_argument("--img_size", type=int, default=224) +from ppcls.utils import config +from ppcls.engine.trainer import Trainer +from ppcls.arch import build_model +from ppcls.utils.save_load import load_dygraph_pretrain - return parser.parse_args() +class ClasModel(nn.Layer): + """ + ClasModel: add softmax onto the model + """ -class Net(paddle.nn.Layer): - def __init__(self, net, class_dim, model): - super(Net, self).__init__() - self.pre_net = net(class_dim=class_dim) - self.model = model + def __init__(self, config): + super().__init__() + self.base_model = build_model(config) + self.softmax = nn.Softmax(axis=-1) def eval(self): self.training = False @@ -53,33 +45,34 @@ class Net(paddle.nn.Layer): layer.training = False layer.eval() - def forward(self, inputs): - x = self.pre_net(inputs) - if self.model == "GoogLeNet": - x = x[0] - x = F.softmax(x) + def forward(self, x): + x = self.base_model(x) + x = self.softmax(x) return x -def main(): - args = parse_args() +if __name__ == "__main__": + args = config.parse_args() + config = config.get_config(args.config, overrides=args.override, show=True) + # set device + assert config["Global"]["device"] in ["cpu", "gpu", "xpu"] + device = paddle.set_device(config["Global"]["device"]) + + model = ClasModel(config["Arch"]) + + if config["Global"]["pretrained_model"] is not None: + load_dygraph_pretrain(model.base_model, + config["Global"]["pretrained_model"]) - net = backbone.__dict__[args.model] - model = Net(net, args.class_dim, args.model) - load_dygraph_pretrain( - model.pre_net, - path=args.pretrained_model, - load_static_weights=args.load_static_weights) model.eval() - model = to_static( + model = paddle.jit.to_static( model, input_spec=[ paddle.static.InputSpec( - shape=[None, 3, args.img_size, args.img_size], dtype='float32') + shape=[None] + config["Global"]["image_shape"], + dtype='float32') ]) - paddle.jit.save(model, os.path.join(args.output_path, "inference")) - - -if __name__ == "__main__": - main() + paddle.jit.save(model, + os.path.join(config["Global"]["save_inference_dir"], + "inference")) diff --git a/tools/infer.py b/tools/infer.py new file mode 100644 index 00000000..256037a7 --- /dev/null +++ b/tools/infer.py @@ -0,0 +1,31 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import os +import sys +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../'))) + +from ppcls.utils import config +from ppcls.engine.trainer import Trainer + +if __name__ == "__main__": + args = config.parse_args() + config = config.get_config(args.config, overrides=args.override, show=True) + trainer = Trainer(config, mode="infer") + + trainer.infer() -- GitLab