diff --git a/PaddleCV/PaddleDetection/tools/configure.py b/PaddleCV/PaddleDetection/tools/configure.py index fa116276594b295f02f0d409a8a3883442f53b16..1ac15864dfc66d2ba1e4455933eb7e6dacb5954b 100644 --- a/PaddleCV/PaddleDetection/tools/configure.py +++ b/PaddleCV/PaddleDetection/tools/configure.py @@ -1,280 +1,280 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function - -import re -import sys -from argparse import ArgumentParser, RawDescriptionHelpFormatter - -import yaml - -import sys -sys.path.append('..') - -from ppdet.core.workspace import get_registered_modules, load_config -from ppdet.utils.cli import ColorTTY - -color_tty = ColorTTY() - -MISC_CONFIG = { - "architecture": "", - "max_iters": "", - "train_feed": "", - "eval_feed": "", - "test_feed": "", - "pretrain_weights": "", - "save_dir": "", - "weights": "", - "metric": "", - "log_smooth_window": 20, - "snapshot_iter": 10000, - "use_gpu": True, -} - - -def dump_value(value): - # XXX this is hackish, but collections.abc is not available in python 2 - if hasattr(value, '__dict__') or isinstance(value, (dict, tuple, list)): - value = yaml.dump(value, default_flow_style=True) - value = value.replace('\n', '') - value = value.replace('...', '') - return "'{}'".format(value) - else: - # primitive types - return str(value) - - -def dump_config(module, minimal=False): - args = module.schema.values() - if minimal: - args = [arg for arg in args if not arg.has_default()] - return yaml.dump( - { - module.name: { - arg.name: arg.default if arg.has_default() else "" - for arg in args - } - }, - default_flow_style=False, - default_style='') - - -def list_modules(**kwargs): - target_category = kwargs['category'] - module_schema = get_registered_modules() - module_by_category = {} - - for schema in module_schema.values(): - category = schema.category - if target_category is not None and schema.category != target_category: - continue - if category not in module_by_category: - module_by_category[category] = [schema] - else: - module_by_category[category].append(schema) - - for cat, modules in module_by_category.items(): - print("Available modules in the category '{}':".format(cat)) - print("") - max_len = max([len(mod.name) for mod in modules]) - for mod in modules: - print(color_tty.green(mod.name.ljust(max_len)), - mod.doc.split('\n')[0]) - print("") - - -def help_module(**kwargs): - schema = get_registered_modules()[kwargs['module']] - - doc = schema.doc is None and "Not documented" or "{}".format(schema.doc) - func_args = {arg.name: arg.doc for arg in schema.schema.values()} - max_len = max([len(k) for k in func_args.keys()]) - opts = "\n".join([ - "{} {}".format(color_tty.green(k.ljust(max_len)), v) - for k, v in func_args.items() - ]) - template = dump_config(schema) - print("{}\n\n{}\n\n{}\n\n{}\n\n{}\n\n{}\n{}\n".format( - color_tty.bold(color_tty.blue("MODULE DESCRIPTION:")), - doc, - color_tty.bold(color_tty.blue("MODULE OPTIONS:")), - opts, - color_tty.bold(color_tty.blue("CONFIGURATION TEMPLATE:")), - template, - color_tty.bold(color_tty.blue("COMMAND LINE OPTIONS:")), )) - for arg in schema.schema.values(): - print("--opt {}.{}={}".format(schema.name, arg.name, - dump_value(arg.default) - if arg.has_default() else "")) - - -def generate_config(**kwargs): - minimal = kwargs['minimal'] - modules = kwargs['modules'] - module_schema = get_registered_modules() - visited = [] - schema = [] - - def walk(m): - if m in visited: - return - s = module_schema[m] - schema.append(s) - visited.append(m) - - for mod in modules: - walk(mod) - - # XXX try to be smart about when to add header, - # if any "architecture" module, is included, head will be added as well - if any([getattr(m, 'category', None) == 'architecture' for m in schema]): - # XXX for ordered printing - header = "" - for k, v in MISC_CONFIG.items(): - header += yaml.dump( - { - k: v - }, default_flow_style=False, default_style='') - print(header) - - for s in schema: - print(dump_config(s, minimal)) - - -# FIXME this is pretty hackish, maybe implement a custom YAML printer? -def analyze_config(**kwargs): - config = load_config(kwargs['file']) - modules = get_registered_modules() - green = '___{}___'.format(color_tty.colors.index('green') + 31) - - styled = {} - for key in config.keys(): - if not config[key]: # empty schema - continue - - if key not in modules and not hasattr(config[key], '__dict__'): - styled[key] = config[key] - continue - elif key in modules: - module = modules[key] - else: - type_name = type(config[key]).__name__ - if type_name in modules: - module = modules[type_name].copy() - module.update({ - k: v - for k, v in config[key].__dict__.items() - if k in module.schema - }) - key += " ({})".format(type_name) - default = module.find_default_keys() - missing = module.find_missing_keys() - mismatch = module.find_mismatch_keys() - extra = module.find_extra_keys() - dep_missing = [] - for dep in module.inject: - if isinstance(module[dep], str) and module[dep] != '': - if module[dep] not in modules: # not a valid module - dep_missing.append(dep) - else: - dep_mod = modules[module[dep]] - # empty dict but mandatory - if not dep_mod and dep_mod.mandatory(): - dep_missing.append(dep) - override = list( - set(module.keys()) - set(default) - set(extra) - set(dep_missing)) - replacement = {} - for name in set(override + default + extra + mismatch + missing): - new_name = name - if name in missing: - value = "" - else: - value = module[name] - - if name in extra: - value = dump_value(value) + " " - elif name in mismatch: - value = dump_value(value) + " " - elif name in dep_missing: - value = dump_value(value) + " " - elif name in override and value != '': - mark = green - new_name = mark + name - replacement[new_name] = value - styled[key] = replacement - buffer = yaml.dump(styled, default_flow_style=False, default_style='') - buffer = (re.sub(r"", r"", buffer)) - buffer = (re.sub(r"", r"", buffer)) - buffer = (re.sub(r"", r"", buffer)) - buffer = (re.sub(r"", - r"", buffer)) - buffer = re.sub(r"___(\d+)___(.*?):", r"[\1m\2:", buffer) - print(buffer) - - -if __name__ == '__main__': - argv = sys.argv[1:] - - parser = ArgumentParser(formatter_class=RawDescriptionHelpFormatter) - subparsers = parser.add_subparsers(help='Supported Commands') - list_parser = subparsers.add_parser("list", help="list available modules") - help_parser = subparsers.add_parser( - "help", help="show detail options for module") - generate_parser = subparsers.add_parser( - "generate", help="generate configuration template") - analyze_parser = subparsers.add_parser( - "analyze", help="analyze configuration file") - - list_parser.set_defaults(func=list_modules) - help_parser.set_defaults(func=help_module) - generate_parser.set_defaults(func=generate_config) - analyze_parser.set_defaults(func=analyze_config) - - list_group = list_parser.add_mutually_exclusive_group() - list_group.add_argument( - "-c", - "--category", - type=str, - default=None, - help="list modules for ") - - help_parser.add_argument( - "module", - help="module to show info for", - choices=list(get_registered_modules().keys())) - - generate_parser.add_argument( - "modules", - nargs='+', - help="include these module in generated configuration template", - choices=list(get_registered_modules().keys())) - generate_group = generate_parser.add_mutually_exclusive_group() - generate_group.add_argument( - "--minimal", action='store_true', help="only include required options") - generate_group.add_argument( - "--full", - action='store_false', - dest='minimal', - help="include all options") - - analyze_parser.add_argument("file", help="configuration file to analyze") - - if len(sys.argv) < 2: - parser.print_help() - sys.exit(1) - - args = parser.parse_args(argv) - if hasattr(args, 'func'): - args.func(**vars(args)) +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import re +import sys +from argparse import ArgumentParser, RawDescriptionHelpFormatter + +import yaml + +import sys +sys.path.append('..') + +from ppdet.core.workspace import get_registered_modules, load_config +from ppdet.utils.cli import ColorTTY + +color_tty = ColorTTY() + +MISC_CONFIG = { + "architecture": "", + "max_iters": "", + "train_feed": "", + "eval_feed": "", + "test_feed": "", + "pretrain_weights": "", + "save_dir": "", + "weights": "", + "metric": "", + "log_smooth_window": 20, + "snapshot_iter": 10000, + "use_gpu": True, +} + + +def dump_value(value): + # XXX this is hackish, but collections.abc is not available in python 2 + if hasattr(value, '__dict__') or isinstance(value, (dict, tuple, list)): + value = yaml.dump(value, default_flow_style=True) + value = value.replace('\n', '') + value = value.replace('...', '') + return "'{}'".format(value) + else: + # primitive types + return str(value) + + +def dump_config(module, minimal=False): + args = module.schema.values() + if minimal: + args = [arg for arg in args if not arg.has_default()] + return yaml.dump( + { + module.name: { + arg.name: arg.default if arg.has_default() else "" + for arg in args + } + }, + default_flow_style=False, + default_style='') + + +def list_modules(**kwargs): + target_category = kwargs['category'] + module_schema = get_registered_modules() + module_by_category = {} + + for schema in module_schema.values(): + category = schema.category + if target_category is not None and schema.category != target_category: + continue + if category not in module_by_category: + module_by_category[category] = [schema] + else: + module_by_category[category].append(schema) + + for cat, modules in module_by_category.items(): + print("Available modules in the category '{}':".format(cat)) + print("") + max_len = max([len(mod.name) for mod in modules]) + for mod in modules: + print(color_tty.green(mod.name.ljust(max_len)), + mod.doc.split('\n')[0]) + print("") + + +def help_module(**kwargs): + schema = get_registered_modules()[kwargs['module']] + + doc = schema.doc is None and "Not documented" or "{}".format(schema.doc) + func_args = {arg.name: arg.doc for arg in schema.schema.values()} + max_len = max([len(k) for k in func_args.keys()]) + opts = "\n".join([ + "{} {}".format(color_tty.green(k.ljust(max_len)), v) + for k, v in func_args.items() + ]) + template = dump_config(schema) + print("{}\n\n{}\n\n{}\n\n{}\n\n{}\n\n{}\n{}\n".format( + color_tty.bold(color_tty.blue("MODULE DESCRIPTION:")), + doc, + color_tty.bold(color_tty.blue("MODULE OPTIONS:")), + opts, + color_tty.bold(color_tty.blue("CONFIGURATION TEMPLATE:")), + template, + color_tty.bold(color_tty.blue("COMMAND LINE OPTIONS:")), )) + for arg in schema.schema.values(): + print("--opt {}.{}={}".format(schema.name, arg.name, + dump_value(arg.default) + if arg.has_default() else "")) + + +def generate_config(**kwargs): + minimal = kwargs['minimal'] + modules = kwargs['modules'] + module_schema = get_registered_modules() + visited = [] + schema = [] + + def walk(m): + if m in visited: + return + s = module_schema[m] + schema.append(s) + visited.append(m) + + for mod in modules: + walk(mod) + + # XXX try to be smart about when to add header, + # if any "architecture" module, is included, head will be added as well + if any([getattr(m, 'category', None) == 'architecture' for m in schema]): + # XXX for ordered printing + header = "" + for k, v in MISC_CONFIG.items(): + header += yaml.dump( + { + k: v + }, default_flow_style=False, default_style='') + print(header) + + for s in schema: + print(dump_config(s, minimal)) + + +# FIXME this is pretty hackish, maybe implement a custom YAML printer? +def analyze_config(**kwargs): + config = load_config(kwargs['file']) + modules = get_registered_modules() + green = '___{}___'.format(color_tty.colors.index('green') + 31) + + styled = {} + for key in config.keys(): + if not config[key]: # empty schema + continue + + if key not in modules and not hasattr(config[key], '__dict__'): + styled[key] = config[key] + continue + elif key in modules: + module = modules[key] + else: + type_name = type(config[key]).__name__ + if type_name in modules: + module = modules[type_name].copy() + module.update({ + k: v + for k, v in config[key].__dict__.items() + if k in module.schema + }) + key += " ({})".format(type_name) + default = module.find_default_keys() + missing = module.find_missing_keys() + mismatch = module.find_mismatch_keys() + extra = module.find_extra_keys() + dep_missing = [] + for dep in module.inject: + if isinstance(module[dep], str) and module[dep] != '': + if module[dep] not in modules: # not a valid module + dep_missing.append(dep) + else: + dep_mod = modules[module[dep]] + # empty dict but mandatory + if not dep_mod and dep_mod.mandatory(): + dep_missing.append(dep) + override = list( + set(module.keys()) - set(default) - set(extra) - set(dep_missing)) + replacement = {} + for name in set(override + default + extra + mismatch + missing): + new_name = name + if name in missing: + value = "" + else: + value = module[name] + + if name in extra: + value = dump_value(value) + " " + elif name in mismatch: + value = dump_value(value) + " " + elif name in dep_missing: + value = dump_value(value) + " " + elif name in override and value != '': + mark = green + new_name = mark + name + replacement[new_name] = value + styled[key] = replacement + buffer = yaml.dump(styled, default_flow_style=False, default_style='') + buffer = (re.sub(r"", r"", buffer)) + buffer = (re.sub(r"", r"", buffer)) + buffer = (re.sub(r"", r"", buffer)) + buffer = (re.sub(r"", + r"", buffer)) + buffer = re.sub(r"___(\d+)___(.*?):", r"[\1m\2:", buffer) + print(buffer) + + +if __name__ == '__main__': + argv = sys.argv[1:] + + parser = ArgumentParser(formatter_class=RawDescriptionHelpFormatter) + subparsers = parser.add_subparsers(help='Supported Commands') + list_parser = subparsers.add_parser("list", help="list available modules") + help_parser = subparsers.add_parser( + "help", help="show detail options for module") + generate_parser = subparsers.add_parser( + "generate", help="generate configuration template") + analyze_parser = subparsers.add_parser( + "analyze", help="analyze configuration file") + + list_parser.set_defaults(func=list_modules) + help_parser.set_defaults(func=help_module) + generate_parser.set_defaults(func=generate_config) + analyze_parser.set_defaults(func=analyze_config) + + list_group = list_parser.add_mutually_exclusive_group() + list_group.add_argument( + "-c", + "--category", + type=str, + default=None, + help="list modules for ") + + help_parser.add_argument( + "module", + help="module to show info for", + choices=list(get_registered_modules().keys())) + + generate_parser.add_argument( + "modules", + nargs='+', + help="include these module in generated configuration template", + choices=list(get_registered_modules().keys())) + generate_group = generate_parser.add_mutually_exclusive_group() + generate_group.add_argument( + "--minimal", action='store_true', help="only include required options") + generate_group.add_argument( + "--full", + action='store_false', + dest='minimal', + help="include all options") + + analyze_parser.add_argument("file", help="configuration file to analyze") + + if len(sys.argv) < 2: + parser.print_help() + sys.exit(1) + + args = parser.parse_args(argv) + if hasattr(args, 'func'): + args.func(**vars(args)) diff --git a/PaddleCV/PaddleDetection/tools/eval.py b/PaddleCV/PaddleDetection/tools/eval.py index 8469732aafe3b94451dd9d00c55d00035ac3aac9..c688cab958e0a38dfcaa7104437d77c88998ae38 100644 --- a/PaddleCV/PaddleDetection/tools/eval.py +++ b/PaddleCV/PaddleDetection/tools/eval.py @@ -1,122 +1,122 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import multiprocessing - -import paddle.fluid as fluid - -import sys -sys.path.append('..') - -from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results -import ppdet.utils.checkpoint as checkpoint -from ppdet.utils.cli import ArgsParser -from ppdet.utils.check import check_gpu -from ppdet.modeling.model_input import create_feed -from ppdet.data.data_feed import create_reader -from ppdet.core.workspace import load_config, merge_config, create - -import logging -FORMAT = '%(asctime)s-%(levelname)s: %(message)s' -logging.basicConfig(level=logging.INFO, format=FORMAT) -logger = logging.getLogger(__name__) - - -def main(): - """ - Main evaluate function - """ - cfg = load_config(FLAGS.config) - if 'architecture' in cfg: - main_arch = cfg.architecture - else: - raise ValueError("'architecture' not specified in config file.") - - merge_config(FLAGS.opt) - - # check if set use_gpu=True in paddlepaddle cpu version - check_gpu(cfg.use_gpu) - - if cfg.use_gpu: - devices_num = fluid.core.get_cuda_device_count() - else: - devices_num = int( - os.environ.get('CPU_NUM', multiprocessing.cpu_count())) - - if 'eval_feed' not in cfg: - eval_feed = create(main_arch + 'EvalFeed') - else: - eval_feed = create(cfg.eval_feed) - - # define executor - place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() - exe = fluid.Executor(place) - - # build program - model = create(main_arch) - startup_prog = fluid.Program() - eval_prog = fluid.Program() - with fluid.program_guard(eval_prog, startup_prog): - with fluid.unique_name.guard(): - pyreader, feed_vars = create_feed(eval_feed) - fetches = model.eval(feed_vars) - eval_prog = eval_prog.clone(True) - - reader = create_reader(eval_feed) - pyreader.decorate_sample_list_generator(reader, place) - - # compile program for multi-devices - if devices_num <= 1: - compile_program = fluid.compiler.CompiledProgram(eval_prog) - else: - build_strategy = fluid.BuildStrategy() - build_strategy.memory_optimize = False - build_strategy.enable_inplace = False - compile_program = fluid.compiler.CompiledProgram( - eval_prog).with_data_parallel(build_strategy=build_strategy) - - # load model - exe.run(startup_prog) - if 'weights' in cfg: - checkpoint.load_pretrain(exe, eval_prog, cfg.weights) - - extra_keys = [] - if 'metric' in cfg and cfg.metric == 'COCO': - extra_keys = ['im_info', 'im_id', 'im_shape'] - - keys, values, cls = parse_fetches(fetches, eval_prog, extra_keys) - - results = eval_run(exe, compile_program, pyreader, keys, values, cls) - # evaluation - resolution = None - if 'mask' in results[0]: - resolution = model.mask_head.resolution - eval_results(results, eval_feed, cfg.metric, resolution, FLAGS.output_file) - - -if __name__ == '__main__': - parser = ArgsParser() - parser.add_argument( - "-f", - "--output_file", - default=None, - type=str, - help="Evaluation file name, default to bbox.json and mask.json.") - FLAGS = parser.parse_args() - main() +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import multiprocessing + +import paddle.fluid as fluid + +import sys +sys.path.append('..') + +from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results +import ppdet.utils.checkpoint as checkpoint +from ppdet.utils.cli import ArgsParser +from ppdet.utils.check import check_gpu +from ppdet.modeling.model_input import create_feed +from ppdet.data.data_feed import create_reader +from ppdet.core.workspace import load_config, merge_config, create + +import logging +FORMAT = '%(asctime)s-%(levelname)s: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + + +def main(): + """ + Main evaluate function + """ + cfg = load_config(FLAGS.config) + if 'architecture' in cfg: + main_arch = cfg.architecture + else: + raise ValueError("'architecture' not specified in config file.") + + merge_config(FLAGS.opt) + + # check if set use_gpu=True in paddlepaddle cpu version + check_gpu(cfg.use_gpu) + + if cfg.use_gpu: + devices_num = fluid.core.get_cuda_device_count() + else: + devices_num = int( + os.environ.get('CPU_NUM', multiprocessing.cpu_count())) + + if 'eval_feed' not in cfg: + eval_feed = create(main_arch + 'EvalFeed') + else: + eval_feed = create(cfg.eval_feed) + + # define executor + place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + + # build program + model = create(main_arch) + startup_prog = fluid.Program() + eval_prog = fluid.Program() + with fluid.program_guard(eval_prog, startup_prog): + with fluid.unique_name.guard(): + pyreader, feed_vars = create_feed(eval_feed) + fetches = model.eval(feed_vars) + eval_prog = eval_prog.clone(True) + + reader = create_reader(eval_feed) + pyreader.decorate_sample_list_generator(reader, place) + + # compile program for multi-devices + if devices_num <= 1: + compile_program = fluid.compiler.CompiledProgram(eval_prog) + else: + build_strategy = fluid.BuildStrategy() + build_strategy.memory_optimize = False + build_strategy.enable_inplace = False + compile_program = fluid.compiler.CompiledProgram( + eval_prog).with_data_parallel(build_strategy=build_strategy) + + # load model + exe.run(startup_prog) + if 'weights' in cfg: + checkpoint.load_pretrain(exe, eval_prog, cfg.weights) + + extra_keys = [] + if 'metric' in cfg and cfg.metric == 'COCO': + extra_keys = ['im_info', 'im_id', 'im_shape'] + + keys, values, cls = parse_fetches(fetches, eval_prog, extra_keys) + + results = eval_run(exe, compile_program, pyreader, keys, values, cls) + # evaluation + resolution = None + if 'mask' in results[0]: + resolution = model.mask_head.resolution + eval_results(results, eval_feed, cfg.metric, resolution, FLAGS.output_file) + + +if __name__ == '__main__': + parser = ArgsParser() + parser.add_argument( + "-f", + "--output_file", + default=None, + type=str, + help="Evaluation file name, default to bbox.json and mask.json.") + FLAGS = parser.parse_args() + main() diff --git a/PaddleCV/PaddleDetection/tools/infer.py b/PaddleCV/PaddleDetection/tools/infer.py index 7e77315be7407cdeed15afa62d2f81bdad363563..bbe4670f6f7e7f417340d99eae7f7c6cf19f2272 100644 --- a/PaddleCV/PaddleDetection/tools/infer.py +++ b/PaddleCV/PaddleDetection/tools/infer.py @@ -1,260 +1,260 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import glob - -import numpy as np -from PIL import Image - -from paddle import fluid - -import sys -sys.path.append('..') - -from ppdet.core.workspace import load_config, merge_config, create -from ppdet.modeling.model_input import create_feed -from ppdet.data.data_feed import create_reader - -from ppdet.utils.eval_utils import parse_fetches -from ppdet.utils.cli import ArgsParser -from ppdet.utils.check import check_gpu -from ppdet.utils.visualizer import visualize_results -import ppdet.utils.checkpoint as checkpoint - -import logging -FORMAT = '%(asctime)s-%(levelname)s: %(message)s' -logging.basicConfig(level=logging.INFO, format=FORMAT) -logger = logging.getLogger(__name__) - - -def get_save_image_name(output_dir, image_path): - """ - Get save image name from source image path. - """ - if not os.path.exists(output_dir): - os.makedirs(output_dir) - image_name = image_path.split('/')[-1] - name, ext = os.path.splitext(image_name) - return os.path.join(output_dir, "{}".format(name)) + ext - - -def get_test_images(infer_dir, infer_img): - """ - Get image path list in TEST mode - """ - assert infer_img is not None or infer_dir is not None, \ - "--infer_img or --infer_dir should be set" - assert infer_img is None or os.path.isfile(infer_img), \ - "{} is not a file".format(infer_img) - assert infer_dir is None or os.path.isdir(infer_dir), \ - "{} is not a directory".format(infer_dir) - images = [] - - # infer_img has a higher priority - if infer_img and os.path.isfile(infer_img): - images.append(infer_img) - return images - - infer_dir = os.path.abspath(infer_dir) - assert os.path.isdir(infer_dir), \ - "infer_dir {} is not a directory".format(infer_dir) - exts = ['jpg', 'jpeg', 'png', 'bmp'] - exts += [ext.upper() for ext in exts] - for ext in exts: - images.extend(glob.glob('{}/*.{}'.format(infer_dir, ext))) - - assert len(images) > 0, "no image found in {}".format(infer_dir) - logger.info("Found {} inference images in total.".format(len(images))) - - return images - - -def prune_feed_vars(feeded_var_names, target_vars, prog): - """ - Filter out feed variables which are not in program, - pruned feed variables are only used in post processing - on model output, which are not used in program, such - as im_id to identify image order, im_shape to clip bbox - in image. - """ - exist_var_names = [] - prog = prog.clone() - prog = prog._prune(targets=target_vars) - global_block = prog.global_block() - for name in feeded_var_names: - try: - v = global_block.var(name) - exist_var_names.append(v.name) - except Exception: - logger.info('save_inference_model pruned unused feed ' - 'variables {}'.format(name)) - pass - return exist_var_names - - -def save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog): - cfg_name = os.path.basename(FLAGS.config).split('.')[0] - save_dir = os.path.join(FLAGS.output_dir, cfg_name) - feeded_var_names = [var.name for var in feed_vars.values()] - target_vars = test_fetches.values() - feeded_var_names = prune_feed_vars(feeded_var_names, target_vars, infer_prog) - logger.info("Save inference model to {}, input: {}, output: " - "{}...".format(save_dir, feeded_var_names, - [var.name for var in target_vars])) - fluid.io.save_inference_model(save_dir, - feeded_var_names=feeded_var_names, - target_vars=target_vars, - executor=exe, - main_program=infer_prog, - params_filename="__params__") - - -def main(): - cfg = load_config(FLAGS.config) - - if 'architecture' in cfg: - main_arch = cfg.architecture - else: - raise ValueError("'architecture' not specified in config file.") - - merge_config(FLAGS.opt) - - # check if set use_gpu=True in paddlepaddle cpu version - check_gpu(cfg.use_gpu) - - if 'test_feed' not in cfg: - test_feed = create(main_arch + 'TestFeed') - else: - test_feed = create(cfg.test_feed) - - test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img) - test_feed.dataset.add_images(test_images) - - place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() - exe = fluid.Executor(place) - - model = create(main_arch) - - startup_prog = fluid.Program() - infer_prog = fluid.Program() - with fluid.program_guard(infer_prog, startup_prog): - with fluid.unique_name.guard(): - _, feed_vars = create_feed(test_feed, use_pyreader=False) - test_fetches = model.test(feed_vars) - infer_prog = infer_prog.clone(True) - - reader = create_reader(test_feed) - feeder = fluid.DataFeeder(place=place, feed_list=feed_vars.values()) - - exe.run(startup_prog) - if cfg.weights: - checkpoint.load_checkpoint(exe, infer_prog, cfg.weights) - - if FLAGS.save_inference_model: - save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog) - - # parse infer fetches - extra_keys = [] - if cfg['metric'] == 'COCO': - extra_keys = ['im_info', 'im_id', 'im_shape'] - if cfg['metric'] == 'VOC': - extra_keys = ['im_id'] - keys, values, _ = parse_fetches(test_fetches, infer_prog, extra_keys) - - # parse dataset category - if cfg.metric == 'COCO': - from ppdet.utils.coco_eval import bbox2out, mask2out, get_category_info - if cfg.metric == "VOC": - from ppdet.utils.voc_eval import bbox2out, get_category_info - - anno_file = getattr(test_feed.dataset, 'annotation', None) - with_background = getattr(test_feed, 'with_background', True) - use_default_label = getattr(test_feed, 'use_default_label', False) - clsid2catid, catid2name = get_category_info(anno_file, with_background, - use_default_label) - - # whether output bbox is normalized in model output layer - is_bbox_normalized = False - if hasattr(model, 'is_bbox_normalized') and \ - callable(model.is_bbox_normalized): - is_bbox_normalized = model.is_bbox_normalized() - - imid2path = reader.imid2path - for iter_id, data in enumerate(reader()): - outs = exe.run(infer_prog, - feed=feeder.feed(data), - fetch_list=values, - return_numpy=False) - res = { - k: (np.array(v), v.recursive_sequence_lengths()) - for k, v in zip(keys, outs) - } - logger.info('Infer iter {}'.format(iter_id)) - - bbox_results = None - mask_results = None - if 'bbox' in res: - bbox_results = bbox2out([res], clsid2catid, is_bbox_normalized) - if 'mask' in res: - mask_results = mask2out([res], clsid2catid, - model.mask_head.resolution) - - # visualize result - im_ids = res['im_id'][0] - for im_id in im_ids: - image_path = imid2path[int(im_id)] - image = Image.open(image_path).convert('RGB') - image = visualize_results(image, - int(im_id), catid2name, - FLAGS.draw_threshold, bbox_results, - mask_results, is_bbox_normalized) - save_name = get_save_image_name(FLAGS.output_dir, image_path) - logger.info("Detection bbox results save in {}".format(save_name)) - image.save(save_name, quality=95) - - -if __name__ == '__main__': - parser = ArgsParser() - parser.add_argument( - "--infer_dir", - type=str, - default=None, - help="Directory for images to perform inference on.") - parser.add_argument( - "--infer_img", - type=str, - default=None, - help="Image path, has higher priority over --infer_dir") - parser.add_argument( - "--output_dir", - type=str, - default="output", - help="Directory for storing the output visualization files.") - parser.add_argument( - "--draw_threshold", - type=float, - default=0.5, - help="Threshold to reserve the result for visualization.") - parser.add_argument( - "--save_inference_model", - action='store_true', - default=False, - help="Save inference model in output_dir if True.") - FLAGS = parser.parse_args() - main() +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import glob + +import numpy as np +from PIL import Image + +from paddle import fluid + +import sys +sys.path.append('..') + +from ppdet.core.workspace import load_config, merge_config, create +from ppdet.modeling.model_input import create_feed +from ppdet.data.data_feed import create_reader + +from ppdet.utils.eval_utils import parse_fetches +from ppdet.utils.cli import ArgsParser +from ppdet.utils.check import check_gpu +from ppdet.utils.visualizer import visualize_results +import ppdet.utils.checkpoint as checkpoint + +import logging +FORMAT = '%(asctime)s-%(levelname)s: %(message)s' +logging.basicConfig(level=logging.INFO, format=FORMAT) +logger = logging.getLogger(__name__) + + +def get_save_image_name(output_dir, image_path): + """ + Get save image name from source image path. + """ + if not os.path.exists(output_dir): + os.makedirs(output_dir) + image_name = image_path.split('/')[-1] + name, ext = os.path.splitext(image_name) + return os.path.join(output_dir, "{}".format(name)) + ext + + +def get_test_images(infer_dir, infer_img): + """ + Get image path list in TEST mode + """ + assert infer_img is not None or infer_dir is not None, \ + "--infer_img or --infer_dir should be set" + assert infer_img is None or os.path.isfile(infer_img), \ + "{} is not a file".format(infer_img) + assert infer_dir is None or os.path.isdir(infer_dir), \ + "{} is not a directory".format(infer_dir) + images = [] + + # infer_img has a higher priority + if infer_img and os.path.isfile(infer_img): + images.append(infer_img) + return images + + infer_dir = os.path.abspath(infer_dir) + assert os.path.isdir(infer_dir), \ + "infer_dir {} is not a directory".format(infer_dir) + exts = ['jpg', 'jpeg', 'png', 'bmp'] + exts += [ext.upper() for ext in exts] + for ext in exts: + images.extend(glob.glob('{}/*.{}'.format(infer_dir, ext))) + + assert len(images) > 0, "no image found in {}".format(infer_dir) + logger.info("Found {} inference images in total.".format(len(images))) + + return images + + +def prune_feed_vars(feeded_var_names, target_vars, prog): + """ + Filter out feed variables which are not in program, + pruned feed variables are only used in post processing + on model output, which are not used in program, such + as im_id to identify image order, im_shape to clip bbox + in image. + """ + exist_var_names = [] + prog = prog.clone() + prog = prog._prune(targets=target_vars) + global_block = prog.global_block() + for name in feeded_var_names: + try: + v = global_block.var(name) + exist_var_names.append(v.name) + except Exception: + logger.info('save_inference_model pruned unused feed ' + 'variables {}'.format(name)) + pass + return exist_var_names + + +def save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog): + cfg_name = os.path.basename(FLAGS.config).split('.')[0] + save_dir = os.path.join(FLAGS.output_dir, cfg_name) + feeded_var_names = [var.name for var in feed_vars.values()] + target_vars = test_fetches.values() + feeded_var_names = prune_feed_vars(feeded_var_names, target_vars, infer_prog) + logger.info("Save inference model to {}, input: {}, output: " + "{}...".format(save_dir, feeded_var_names, + [var.name for var in target_vars])) + fluid.io.save_inference_model(save_dir, + feeded_var_names=feeded_var_names, + target_vars=target_vars, + executor=exe, + main_program=infer_prog, + params_filename="__params__") + + +def main(): + cfg = load_config(FLAGS.config) + + if 'architecture' in cfg: + main_arch = cfg.architecture + else: + raise ValueError("'architecture' not specified in config file.") + + merge_config(FLAGS.opt) + + # check if set use_gpu=True in paddlepaddle cpu version + check_gpu(cfg.use_gpu) + + if 'test_feed' not in cfg: + test_feed = create(main_arch + 'TestFeed') + else: + test_feed = create(cfg.test_feed) + + test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img) + test_feed.dataset.add_images(test_images) + + place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + + model = create(main_arch) + + startup_prog = fluid.Program() + infer_prog = fluid.Program() + with fluid.program_guard(infer_prog, startup_prog): + with fluid.unique_name.guard(): + _, feed_vars = create_feed(test_feed, use_pyreader=False) + test_fetches = model.test(feed_vars) + infer_prog = infer_prog.clone(True) + + reader = create_reader(test_feed) + feeder = fluid.DataFeeder(place=place, feed_list=feed_vars.values()) + + exe.run(startup_prog) + if cfg.weights: + checkpoint.load_checkpoint(exe, infer_prog, cfg.weights) + + if FLAGS.save_inference_model: + save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog) + + # parse infer fetches + extra_keys = [] + if cfg['metric'] == 'COCO': + extra_keys = ['im_info', 'im_id', 'im_shape'] + if cfg['metric'] == 'VOC': + extra_keys = ['im_id'] + keys, values, _ = parse_fetches(test_fetches, infer_prog, extra_keys) + + # parse dataset category + if cfg.metric == 'COCO': + from ppdet.utils.coco_eval import bbox2out, mask2out, get_category_info + if cfg.metric == "VOC": + from ppdet.utils.voc_eval import bbox2out, get_category_info + + anno_file = getattr(test_feed.dataset, 'annotation', None) + with_background = getattr(test_feed, 'with_background', True) + use_default_label = getattr(test_feed, 'use_default_label', False) + clsid2catid, catid2name = get_category_info(anno_file, with_background, + use_default_label) + + # whether output bbox is normalized in model output layer + is_bbox_normalized = False + if hasattr(model, 'is_bbox_normalized') and \ + callable(model.is_bbox_normalized): + is_bbox_normalized = model.is_bbox_normalized() + + imid2path = reader.imid2path + for iter_id, data in enumerate(reader()): + outs = exe.run(infer_prog, + feed=feeder.feed(data), + fetch_list=values, + return_numpy=False) + res = { + k: (np.array(v), v.recursive_sequence_lengths()) + for k, v in zip(keys, outs) + } + logger.info('Infer iter {}'.format(iter_id)) + + bbox_results = None + mask_results = None + if 'bbox' in res: + bbox_results = bbox2out([res], clsid2catid, is_bbox_normalized) + if 'mask' in res: + mask_results = mask2out([res], clsid2catid, + model.mask_head.resolution) + + # visualize result + im_ids = res['im_id'][0] + for im_id in im_ids: + image_path = imid2path[int(im_id)] + image = Image.open(image_path).convert('RGB') + image = visualize_results(image, + int(im_id), catid2name, + FLAGS.draw_threshold, bbox_results, + mask_results, is_bbox_normalized) + save_name = get_save_image_name(FLAGS.output_dir, image_path) + logger.info("Detection bbox results save in {}".format(save_name)) + image.save(save_name, quality=95) + + +if __name__ == '__main__': + parser = ArgsParser() + parser.add_argument( + "--infer_dir", + type=str, + default=None, + help="Directory for images to perform inference on.") + parser.add_argument( + "--infer_img", + type=str, + default=None, + help="Image path, has higher priority over --infer_dir") + parser.add_argument( + "--output_dir", + type=str, + default="output", + help="Directory for storing the output visualization files.") + parser.add_argument( + "--draw_threshold", + type=float, + default=0.5, + help="Threshold to reserve the result for visualization.") + parser.add_argument( + "--save_inference_model", + action='store_true', + default=False, + help="Save inference model in output_dir if True.") + FLAGS = parser.parse_args() + main()