提交 9be01b91 编写于 作者: E edencfc

fix the import path

上级 a488cf69
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import re import re
import sys import sys
from argparse import ArgumentParser, RawDescriptionHelpFormatter from argparse import ArgumentParser, RawDescriptionHelpFormatter
import yaml import yaml
import sys import sys
sys.path.append('..') sys.path.append('..')
from ppdet.core.workspace import get_registered_modules, load_config from ppdet.core.workspace import get_registered_modules, load_config
from ppdet.utils.cli import ColorTTY from ppdet.utils.cli import ColorTTY
color_tty = ColorTTY() color_tty = ColorTTY()
MISC_CONFIG = { MISC_CONFIG = {
"architecture": "<value>", "architecture": "<value>",
"max_iters": "<value>", "max_iters": "<value>",
"train_feed": "<value>", "train_feed": "<value>",
"eval_feed": "<value>", "eval_feed": "<value>",
"test_feed": "<value>", "test_feed": "<value>",
"pretrain_weights": "<value>", "pretrain_weights": "<value>",
"save_dir": "<value>", "save_dir": "<value>",
"weights": "<value>", "weights": "<value>",
"metric": "<value>", "metric": "<value>",
"log_smooth_window": 20, "log_smooth_window": 20,
"snapshot_iter": 10000, "snapshot_iter": 10000,
"use_gpu": True, "use_gpu": True,
} }
def dump_value(value): def dump_value(value):
# XXX this is hackish, but collections.abc is not available in python 2 # XXX this is hackish, but collections.abc is not available in python 2
if hasattr(value, '__dict__') or isinstance(value, (dict, tuple, list)): if hasattr(value, '__dict__') or isinstance(value, (dict, tuple, list)):
value = yaml.dump(value, default_flow_style=True) value = yaml.dump(value, default_flow_style=True)
value = value.replace('\n', '') value = value.replace('\n', '')
value = value.replace('...', '') value = value.replace('...', '')
return "'{}'".format(value) return "'{}'".format(value)
else: else:
# primitive types # primitive types
return str(value) return str(value)
def dump_config(module, minimal=False): def dump_config(module, minimal=False):
args = module.schema.values() args = module.schema.values()
if minimal: if minimal:
args = [arg for arg in args if not arg.has_default()] args = [arg for arg in args if not arg.has_default()]
return yaml.dump( return yaml.dump(
{ {
module.name: { module.name: {
arg.name: arg.default if arg.has_default() else "<value>" arg.name: arg.default if arg.has_default() else "<value>"
for arg in args for arg in args
} }
}, },
default_flow_style=False, default_flow_style=False,
default_style='') default_style='')
def list_modules(**kwargs): def list_modules(**kwargs):
target_category = kwargs['category'] target_category = kwargs['category']
module_schema = get_registered_modules() module_schema = get_registered_modules()
module_by_category = {} module_by_category = {}
for schema in module_schema.values(): for schema in module_schema.values():
category = schema.category category = schema.category
if target_category is not None and schema.category != target_category: if target_category is not None and schema.category != target_category:
continue continue
if category not in module_by_category: if category not in module_by_category:
module_by_category[category] = [schema] module_by_category[category] = [schema]
else: else:
module_by_category[category].append(schema) module_by_category[category].append(schema)
for cat, modules in module_by_category.items(): for cat, modules in module_by_category.items():
print("Available modules in the category '{}':".format(cat)) print("Available modules in the category '{}':".format(cat))
print("") print("")
max_len = max([len(mod.name) for mod in modules]) max_len = max([len(mod.name) for mod in modules])
for mod in modules: for mod in modules:
print(color_tty.green(mod.name.ljust(max_len)), print(color_tty.green(mod.name.ljust(max_len)),
mod.doc.split('\n')[0]) mod.doc.split('\n')[0])
print("") print("")
def help_module(**kwargs): def help_module(**kwargs):
schema = get_registered_modules()[kwargs['module']] schema = get_registered_modules()[kwargs['module']]
doc = schema.doc is None and "Not documented" or "{}".format(schema.doc) doc = schema.doc is None and "Not documented" or "{}".format(schema.doc)
func_args = {arg.name: arg.doc for arg in schema.schema.values()} func_args = {arg.name: arg.doc for arg in schema.schema.values()}
max_len = max([len(k) for k in func_args.keys()]) max_len = max([len(k) for k in func_args.keys()])
opts = "\n".join([ opts = "\n".join([
"{} {}".format(color_tty.green(k.ljust(max_len)), v) "{} {}".format(color_tty.green(k.ljust(max_len)), v)
for k, v in func_args.items() for k, v in func_args.items()
]) ])
template = dump_config(schema) template = dump_config(schema)
print("{}\n\n{}\n\n{}\n\n{}\n\n{}\n\n{}\n{}\n".format( print("{}\n\n{}\n\n{}\n\n{}\n\n{}\n\n{}\n{}\n".format(
color_tty.bold(color_tty.blue("MODULE DESCRIPTION:")), color_tty.bold(color_tty.blue("MODULE DESCRIPTION:")),
doc, doc,
color_tty.bold(color_tty.blue("MODULE OPTIONS:")), color_tty.bold(color_tty.blue("MODULE OPTIONS:")),
opts, opts,
color_tty.bold(color_tty.blue("CONFIGURATION TEMPLATE:")), color_tty.bold(color_tty.blue("CONFIGURATION TEMPLATE:")),
template, template,
color_tty.bold(color_tty.blue("COMMAND LINE OPTIONS:")), )) color_tty.bold(color_tty.blue("COMMAND LINE OPTIONS:")), ))
for arg in schema.schema.values(): for arg in schema.schema.values():
print("--opt {}.{}={}".format(schema.name, arg.name, print("--opt {}.{}={}".format(schema.name, arg.name,
dump_value(arg.default) dump_value(arg.default)
if arg.has_default() else "<value>")) if arg.has_default() else "<value>"))
def generate_config(**kwargs): def generate_config(**kwargs):
minimal = kwargs['minimal'] minimal = kwargs['minimal']
modules = kwargs['modules'] modules = kwargs['modules']
module_schema = get_registered_modules() module_schema = get_registered_modules()
visited = [] visited = []
schema = [] schema = []
def walk(m): def walk(m):
if m in visited: if m in visited:
return return
s = module_schema[m] s = module_schema[m]
schema.append(s) schema.append(s)
visited.append(m) visited.append(m)
for mod in modules: for mod in modules:
walk(mod) walk(mod)
# XXX try to be smart about when to add header, # XXX try to be smart about when to add header,
# if any "architecture" module, is included, head will be added as well # if any "architecture" module, is included, head will be added as well
if any([getattr(m, 'category', None) == 'architecture' for m in schema]): if any([getattr(m, 'category', None) == 'architecture' for m in schema]):
# XXX for ordered printing # XXX for ordered printing
header = "" header = ""
for k, v in MISC_CONFIG.items(): for k, v in MISC_CONFIG.items():
header += yaml.dump( header += yaml.dump(
{ {
k: v k: v
}, default_flow_style=False, default_style='') }, default_flow_style=False, default_style='')
print(header) print(header)
for s in schema: for s in schema:
print(dump_config(s, minimal)) print(dump_config(s, minimal))
# FIXME this is pretty hackish, maybe implement a custom YAML printer? # FIXME this is pretty hackish, maybe implement a custom YAML printer?
def analyze_config(**kwargs): def analyze_config(**kwargs):
config = load_config(kwargs['file']) config = load_config(kwargs['file'])
modules = get_registered_modules() modules = get_registered_modules()
green = '___{}___'.format(color_tty.colors.index('green') + 31) green = '___{}___'.format(color_tty.colors.index('green') + 31)
styled = {} styled = {}
for key in config.keys(): for key in config.keys():
if not config[key]: # empty schema if not config[key]: # empty schema
continue continue
if key not in modules and not hasattr(config[key], '__dict__'): if key not in modules and not hasattr(config[key], '__dict__'):
styled[key] = config[key] styled[key] = config[key]
continue continue
elif key in modules: elif key in modules:
module = modules[key] module = modules[key]
else: else:
type_name = type(config[key]).__name__ type_name = type(config[key]).__name__
if type_name in modules: if type_name in modules:
module = modules[type_name].copy() module = modules[type_name].copy()
module.update({ module.update({
k: v k: v
for k, v in config[key].__dict__.items() for k, v in config[key].__dict__.items()
if k in module.schema if k in module.schema
}) })
key += " ({})".format(type_name) key += " ({})".format(type_name)
default = module.find_default_keys() default = module.find_default_keys()
missing = module.find_missing_keys() missing = module.find_missing_keys()
mismatch = module.find_mismatch_keys() mismatch = module.find_mismatch_keys()
extra = module.find_extra_keys() extra = module.find_extra_keys()
dep_missing = [] dep_missing = []
for dep in module.inject: for dep in module.inject:
if isinstance(module[dep], str) and module[dep] != '<value>': if isinstance(module[dep], str) and module[dep] != '<value>':
if module[dep] not in modules: # not a valid module if module[dep] not in modules: # not a valid module
dep_missing.append(dep) dep_missing.append(dep)
else: else:
dep_mod = modules[module[dep]] dep_mod = modules[module[dep]]
# empty dict but mandatory # empty dict but mandatory
if not dep_mod and dep_mod.mandatory(): if not dep_mod and dep_mod.mandatory():
dep_missing.append(dep) dep_missing.append(dep)
override = list( override = list(
set(module.keys()) - set(default) - set(extra) - set(dep_missing)) set(module.keys()) - set(default) - set(extra) - set(dep_missing))
replacement = {} replacement = {}
for name in set(override + default + extra + mismatch + missing): for name in set(override + default + extra + mismatch + missing):
new_name = name new_name = name
if name in missing: if name in missing:
value = "<missing>" value = "<missing>"
else: else:
value = module[name] value = module[name]
if name in extra: if name in extra:
value = dump_value(value) + " <extraneous>" value = dump_value(value) + " <extraneous>"
elif name in mismatch: elif name in mismatch:
value = dump_value(value) + " <type mismatch>" value = dump_value(value) + " <type mismatch>"
elif name in dep_missing: elif name in dep_missing:
value = dump_value(value) + " <module config missing>" value = dump_value(value) + " <module config missing>"
elif name in override and value != '<missing>': elif name in override and value != '<missing>':
mark = green mark = green
new_name = mark + name new_name = mark + name
replacement[new_name] = value replacement[new_name] = value
styled[key] = replacement styled[key] = replacement
buffer = yaml.dump(styled, default_flow_style=False, default_style='') buffer = yaml.dump(styled, default_flow_style=False, default_style='')
buffer = (re.sub(r"<missing>", r"<missing>", buffer)) buffer = (re.sub(r"<missing>", r"<missing>", buffer))
buffer = (re.sub(r"<extraneous>", r"<extraneous>", buffer)) buffer = (re.sub(r"<extraneous>", r"<extraneous>", buffer))
buffer = (re.sub(r"<type mismatch>", r"<type mismatch>", buffer)) buffer = (re.sub(r"<type mismatch>", r"<type mismatch>", buffer))
buffer = (re.sub(r"<module config missing>", buffer = (re.sub(r"<module config missing>",
r"<module config missing>", buffer)) r"<module config missing>", buffer))
buffer = re.sub(r"___(\d+)___(.*?):", r"[\1m\2:", buffer) buffer = re.sub(r"___(\d+)___(.*?):", r"[\1m\2:", buffer)
print(buffer) print(buffer)
if __name__ == '__main__': if __name__ == '__main__':
argv = sys.argv[1:] argv = sys.argv[1:]
parser = ArgumentParser(formatter_class=RawDescriptionHelpFormatter) parser = ArgumentParser(formatter_class=RawDescriptionHelpFormatter)
subparsers = parser.add_subparsers(help='Supported Commands') subparsers = parser.add_subparsers(help='Supported Commands')
list_parser = subparsers.add_parser("list", help="list available modules") list_parser = subparsers.add_parser("list", help="list available modules")
help_parser = subparsers.add_parser( help_parser = subparsers.add_parser(
"help", help="show detail options for module") "help", help="show detail options for module")
generate_parser = subparsers.add_parser( generate_parser = subparsers.add_parser(
"generate", help="generate configuration template") "generate", help="generate configuration template")
analyze_parser = subparsers.add_parser( analyze_parser = subparsers.add_parser(
"analyze", help="analyze configuration file") "analyze", help="analyze configuration file")
list_parser.set_defaults(func=list_modules) list_parser.set_defaults(func=list_modules)
help_parser.set_defaults(func=help_module) help_parser.set_defaults(func=help_module)
generate_parser.set_defaults(func=generate_config) generate_parser.set_defaults(func=generate_config)
analyze_parser.set_defaults(func=analyze_config) analyze_parser.set_defaults(func=analyze_config)
list_group = list_parser.add_mutually_exclusive_group() list_group = list_parser.add_mutually_exclusive_group()
list_group.add_argument( list_group.add_argument(
"-c", "-c",
"--category", "--category",
type=str, type=str,
default=None, default=None,
help="list modules for <category>") help="list modules for <category>")
help_parser.add_argument( help_parser.add_argument(
"module", "module",
help="module to show info for", help="module to show info for",
choices=list(get_registered_modules().keys())) choices=list(get_registered_modules().keys()))
generate_parser.add_argument( generate_parser.add_argument(
"modules", "modules",
nargs='+', nargs='+',
help="include these module in generated configuration template", help="include these module in generated configuration template",
choices=list(get_registered_modules().keys())) choices=list(get_registered_modules().keys()))
generate_group = generate_parser.add_mutually_exclusive_group() generate_group = generate_parser.add_mutually_exclusive_group()
generate_group.add_argument( generate_group.add_argument(
"--minimal", action='store_true', help="only include required options") "--minimal", action='store_true', help="only include required options")
generate_group.add_argument( generate_group.add_argument(
"--full", "--full",
action='store_false', action='store_false',
dest='minimal', dest='minimal',
help="include all options") help="include all options")
analyze_parser.add_argument("file", help="configuration file to analyze") analyze_parser.add_argument("file", help="configuration file to analyze")
if len(sys.argv) < 2: if len(sys.argv) < 2:
parser.print_help() parser.print_help()
sys.exit(1) sys.exit(1)
args = parser.parse_args(argv) args = parser.parse_args(argv)
if hasattr(args, 'func'): if hasattr(args, 'func'):
args.func(**vars(args)) args.func(**vars(args))
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import multiprocessing import multiprocessing
import paddle.fluid as fluid import paddle.fluid as fluid
import sys import sys
sys.path.append('..') sys.path.append('..')
from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results
import ppdet.utils.checkpoint as checkpoint import ppdet.utils.checkpoint as checkpoint
from ppdet.utils.cli import ArgsParser from ppdet.utils.cli import ArgsParser
from ppdet.utils.check import check_gpu from ppdet.utils.check import check_gpu
from ppdet.modeling.model_input import create_feed from ppdet.modeling.model_input import create_feed
from ppdet.data.data_feed import create_reader from ppdet.data.data_feed import create_reader
from ppdet.core.workspace import load_config, merge_config, create from ppdet.core.workspace import load_config, merge_config, create
import logging import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s' FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT) logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def main(): def main():
""" """
Main evaluate function Main evaluate function
""" """
cfg = load_config(FLAGS.config) cfg = load_config(FLAGS.config)
if 'architecture' in cfg: if 'architecture' in cfg:
main_arch = cfg.architecture main_arch = cfg.architecture
else: else:
raise ValueError("'architecture' not specified in config file.") raise ValueError("'architecture' not specified in config file.")
merge_config(FLAGS.opt) merge_config(FLAGS.opt)
# check if set use_gpu=True in paddlepaddle cpu version # check if set use_gpu=True in paddlepaddle cpu version
check_gpu(cfg.use_gpu) check_gpu(cfg.use_gpu)
if cfg.use_gpu: if cfg.use_gpu:
devices_num = fluid.core.get_cuda_device_count() devices_num = fluid.core.get_cuda_device_count()
else: else:
devices_num = int( devices_num = int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count())) os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
if 'eval_feed' not in cfg: if 'eval_feed' not in cfg:
eval_feed = create(main_arch + 'EvalFeed') eval_feed = create(main_arch + 'EvalFeed')
else: else:
eval_feed = create(cfg.eval_feed) eval_feed = create(cfg.eval_feed)
# define executor # define executor
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
# build program # build program
model = create(main_arch) model = create(main_arch)
startup_prog = fluid.Program() startup_prog = fluid.Program()
eval_prog = fluid.Program() eval_prog = fluid.Program()
with fluid.program_guard(eval_prog, startup_prog): with fluid.program_guard(eval_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
pyreader, feed_vars = create_feed(eval_feed) pyreader, feed_vars = create_feed(eval_feed)
fetches = model.eval(feed_vars) fetches = model.eval(feed_vars)
eval_prog = eval_prog.clone(True) eval_prog = eval_prog.clone(True)
reader = create_reader(eval_feed) reader = create_reader(eval_feed)
pyreader.decorate_sample_list_generator(reader, place) pyreader.decorate_sample_list_generator(reader, place)
# compile program for multi-devices # compile program for multi-devices
if devices_num <= 1: if devices_num <= 1:
compile_program = fluid.compiler.CompiledProgram(eval_prog) compile_program = fluid.compiler.CompiledProgram(eval_prog)
else: else:
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = False build_strategy.memory_optimize = False
build_strategy.enable_inplace = False build_strategy.enable_inplace = False
compile_program = fluid.compiler.CompiledProgram( compile_program = fluid.compiler.CompiledProgram(
eval_prog).with_data_parallel(build_strategy=build_strategy) eval_prog).with_data_parallel(build_strategy=build_strategy)
# load model # load model
exe.run(startup_prog) exe.run(startup_prog)
if 'weights' in cfg: if 'weights' in cfg:
checkpoint.load_pretrain(exe, eval_prog, cfg.weights) checkpoint.load_pretrain(exe, eval_prog, cfg.weights)
extra_keys = [] extra_keys = []
if 'metric' in cfg and cfg.metric == 'COCO': if 'metric' in cfg and cfg.metric == 'COCO':
extra_keys = ['im_info', 'im_id', 'im_shape'] extra_keys = ['im_info', 'im_id', 'im_shape']
keys, values, cls = parse_fetches(fetches, eval_prog, extra_keys) keys, values, cls = parse_fetches(fetches, eval_prog, extra_keys)
results = eval_run(exe, compile_program, pyreader, keys, values, cls) results = eval_run(exe, compile_program, pyreader, keys, values, cls)
# evaluation # evaluation
resolution = None resolution = None
if 'mask' in results[0]: if 'mask' in results[0]:
resolution = model.mask_head.resolution resolution = model.mask_head.resolution
eval_results(results, eval_feed, cfg.metric, resolution, FLAGS.output_file) eval_results(results, eval_feed, cfg.metric, resolution, FLAGS.output_file)
if __name__ == '__main__': if __name__ == '__main__':
parser = ArgsParser() parser = ArgsParser()
parser.add_argument( parser.add_argument(
"-f", "-f",
"--output_file", "--output_file",
default=None, default=None,
type=str, type=str,
help="Evaluation file name, default to bbox.json and mask.json.") help="Evaluation file name, default to bbox.json and mask.json.")
FLAGS = parser.parse_args() FLAGS = parser.parse_args()
main() main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
# #
# http://www.apache.org/licenses/LICENSE-2.0 # http://www.apache.org/licenses/LICENSE-2.0
# #
# Unless required by applicable law or agreed to in writing, software # Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import os import os
import glob import glob
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from paddle import fluid from paddle import fluid
import sys import sys
sys.path.append('..') sys.path.append('..')
from ppdet.core.workspace import load_config, merge_config, create from ppdet.core.workspace import load_config, merge_config, create
from ppdet.modeling.model_input import create_feed from ppdet.modeling.model_input import create_feed
from ppdet.data.data_feed import create_reader from ppdet.data.data_feed import create_reader
from ppdet.utils.eval_utils import parse_fetches from ppdet.utils.eval_utils import parse_fetches
from ppdet.utils.cli import ArgsParser from ppdet.utils.cli import ArgsParser
from ppdet.utils.check import check_gpu from ppdet.utils.check import check_gpu
from ppdet.utils.visualizer import visualize_results from ppdet.utils.visualizer import visualize_results
import ppdet.utils.checkpoint as checkpoint import ppdet.utils.checkpoint as checkpoint
import logging import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s' FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT) logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def get_save_image_name(output_dir, image_path): def get_save_image_name(output_dir, image_path):
""" """
Get save image name from source image path. Get save image name from source image path.
""" """
if not os.path.exists(output_dir): if not os.path.exists(output_dir):
os.makedirs(output_dir) os.makedirs(output_dir)
image_name = image_path.split('/')[-1] image_name = image_path.split('/')[-1]
name, ext = os.path.splitext(image_name) name, ext = os.path.splitext(image_name)
return os.path.join(output_dir, "{}".format(name)) + ext return os.path.join(output_dir, "{}".format(name)) + ext
def get_test_images(infer_dir, infer_img): def get_test_images(infer_dir, infer_img):
""" """
Get image path list in TEST mode Get image path list in TEST mode
""" """
assert infer_img is not None or infer_dir is not None, \ assert infer_img is not None or infer_dir is not None, \
"--infer_img or --infer_dir should be set" "--infer_img or --infer_dir should be set"
assert infer_img is None or os.path.isfile(infer_img), \ assert infer_img is None or os.path.isfile(infer_img), \
"{} is not a file".format(infer_img) "{} is not a file".format(infer_img)
assert infer_dir is None or os.path.isdir(infer_dir), \ assert infer_dir is None or os.path.isdir(infer_dir), \
"{} is not a directory".format(infer_dir) "{} is not a directory".format(infer_dir)
images = [] images = []
# infer_img has a higher priority # infer_img has a higher priority
if infer_img and os.path.isfile(infer_img): if infer_img and os.path.isfile(infer_img):
images.append(infer_img) images.append(infer_img)
return images return images
infer_dir = os.path.abspath(infer_dir) infer_dir = os.path.abspath(infer_dir)
assert os.path.isdir(infer_dir), \ assert os.path.isdir(infer_dir), \
"infer_dir {} is not a directory".format(infer_dir) "infer_dir {} is not a directory".format(infer_dir)
exts = ['jpg', 'jpeg', 'png', 'bmp'] exts = ['jpg', 'jpeg', 'png', 'bmp']
exts += [ext.upper() for ext in exts] exts += [ext.upper() for ext in exts]
for ext in exts: for ext in exts:
images.extend(glob.glob('{}/*.{}'.format(infer_dir, ext))) images.extend(glob.glob('{}/*.{}'.format(infer_dir, ext)))
assert len(images) > 0, "no image found in {}".format(infer_dir) assert len(images) > 0, "no image found in {}".format(infer_dir)
logger.info("Found {} inference images in total.".format(len(images))) logger.info("Found {} inference images in total.".format(len(images)))
return images return images
def prune_feed_vars(feeded_var_names, target_vars, prog): def prune_feed_vars(feeded_var_names, target_vars, prog):
""" """
Filter out feed variables which are not in program, Filter out feed variables which are not in program,
pruned feed variables are only used in post processing pruned feed variables are only used in post processing
on model output, which are not used in program, such on model output, which are not used in program, such
as im_id to identify image order, im_shape to clip bbox as im_id to identify image order, im_shape to clip bbox
in image. in image.
""" """
exist_var_names = [] exist_var_names = []
prog = prog.clone() prog = prog.clone()
prog = prog._prune(targets=target_vars) prog = prog._prune(targets=target_vars)
global_block = prog.global_block() global_block = prog.global_block()
for name in feeded_var_names: for name in feeded_var_names:
try: try:
v = global_block.var(name) v = global_block.var(name)
exist_var_names.append(v.name) exist_var_names.append(v.name)
except Exception: except Exception:
logger.info('save_inference_model pruned unused feed ' logger.info('save_inference_model pruned unused feed '
'variables {}'.format(name)) 'variables {}'.format(name))
pass pass
return exist_var_names return exist_var_names
def save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog): def save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog):
cfg_name = os.path.basename(FLAGS.config).split('.')[0] cfg_name = os.path.basename(FLAGS.config).split('.')[0]
save_dir = os.path.join(FLAGS.output_dir, cfg_name) save_dir = os.path.join(FLAGS.output_dir, cfg_name)
feeded_var_names = [var.name for var in feed_vars.values()] feeded_var_names = [var.name for var in feed_vars.values()]
target_vars = test_fetches.values() target_vars = test_fetches.values()
feeded_var_names = prune_feed_vars(feeded_var_names, target_vars, infer_prog) feeded_var_names = prune_feed_vars(feeded_var_names, target_vars, infer_prog)
logger.info("Save inference model to {}, input: {}, output: " logger.info("Save inference model to {}, input: {}, output: "
"{}...".format(save_dir, feeded_var_names, "{}...".format(save_dir, feeded_var_names,
[var.name for var in target_vars])) [var.name for var in target_vars]))
fluid.io.save_inference_model(save_dir, fluid.io.save_inference_model(save_dir,
feeded_var_names=feeded_var_names, feeded_var_names=feeded_var_names,
target_vars=target_vars, target_vars=target_vars,
executor=exe, executor=exe,
main_program=infer_prog, main_program=infer_prog,
params_filename="__params__") params_filename="__params__")
def main(): def main():
cfg = load_config(FLAGS.config) cfg = load_config(FLAGS.config)
if 'architecture' in cfg: if 'architecture' in cfg:
main_arch = cfg.architecture main_arch = cfg.architecture
else: else:
raise ValueError("'architecture' not specified in config file.") raise ValueError("'architecture' not specified in config file.")
merge_config(FLAGS.opt) merge_config(FLAGS.opt)
# check if set use_gpu=True in paddlepaddle cpu version # check if set use_gpu=True in paddlepaddle cpu version
check_gpu(cfg.use_gpu) check_gpu(cfg.use_gpu)
if 'test_feed' not in cfg: if 'test_feed' not in cfg:
test_feed = create(main_arch + 'TestFeed') test_feed = create(main_arch + 'TestFeed')
else: else:
test_feed = create(cfg.test_feed) test_feed = create(cfg.test_feed)
test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img) test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)
test_feed.dataset.add_images(test_images) test_feed.dataset.add_images(test_images)
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
model = create(main_arch) model = create(main_arch)
startup_prog = fluid.Program() startup_prog = fluid.Program()
infer_prog = fluid.Program() infer_prog = fluid.Program()
with fluid.program_guard(infer_prog, startup_prog): with fluid.program_guard(infer_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
_, feed_vars = create_feed(test_feed, use_pyreader=False) _, feed_vars = create_feed(test_feed, use_pyreader=False)
test_fetches = model.test(feed_vars) test_fetches = model.test(feed_vars)
infer_prog = infer_prog.clone(True) infer_prog = infer_prog.clone(True)
reader = create_reader(test_feed) reader = create_reader(test_feed)
feeder = fluid.DataFeeder(place=place, feed_list=feed_vars.values()) feeder = fluid.DataFeeder(place=place, feed_list=feed_vars.values())
exe.run(startup_prog) exe.run(startup_prog)
if cfg.weights: if cfg.weights:
checkpoint.load_checkpoint(exe, infer_prog, cfg.weights) checkpoint.load_checkpoint(exe, infer_prog, cfg.weights)
if FLAGS.save_inference_model: if FLAGS.save_inference_model:
save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog) save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog)
# parse infer fetches # parse infer fetches
extra_keys = [] extra_keys = []
if cfg['metric'] == 'COCO': if cfg['metric'] == 'COCO':
extra_keys = ['im_info', 'im_id', 'im_shape'] extra_keys = ['im_info', 'im_id', 'im_shape']
if cfg['metric'] == 'VOC': if cfg['metric'] == 'VOC':
extra_keys = ['im_id'] extra_keys = ['im_id']
keys, values, _ = parse_fetches(test_fetches, infer_prog, extra_keys) keys, values, _ = parse_fetches(test_fetches, infer_prog, extra_keys)
# parse dataset category # parse dataset category
if cfg.metric == 'COCO': if cfg.metric == 'COCO':
from ppdet.utils.coco_eval import bbox2out, mask2out, get_category_info from ppdet.utils.coco_eval import bbox2out, mask2out, get_category_info
if cfg.metric == "VOC": if cfg.metric == "VOC":
from ppdet.utils.voc_eval import bbox2out, get_category_info from ppdet.utils.voc_eval import bbox2out, get_category_info
anno_file = getattr(test_feed.dataset, 'annotation', None) anno_file = getattr(test_feed.dataset, 'annotation', None)
with_background = getattr(test_feed, 'with_background', True) with_background = getattr(test_feed, 'with_background', True)
use_default_label = getattr(test_feed, 'use_default_label', False) use_default_label = getattr(test_feed, 'use_default_label', False)
clsid2catid, catid2name = get_category_info(anno_file, with_background, clsid2catid, catid2name = get_category_info(anno_file, with_background,
use_default_label) use_default_label)
# whether output bbox is normalized in model output layer # whether output bbox is normalized in model output layer
is_bbox_normalized = False is_bbox_normalized = False
if hasattr(model, 'is_bbox_normalized') and \ if hasattr(model, 'is_bbox_normalized') and \
callable(model.is_bbox_normalized): callable(model.is_bbox_normalized):
is_bbox_normalized = model.is_bbox_normalized() is_bbox_normalized = model.is_bbox_normalized()
imid2path = reader.imid2path imid2path = reader.imid2path
for iter_id, data in enumerate(reader()): for iter_id, data in enumerate(reader()):
outs = exe.run(infer_prog, outs = exe.run(infer_prog,
feed=feeder.feed(data), feed=feeder.feed(data),
fetch_list=values, fetch_list=values,
return_numpy=False) return_numpy=False)
res = { res = {
k: (np.array(v), v.recursive_sequence_lengths()) k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(keys, outs) for k, v in zip(keys, outs)
} }
logger.info('Infer iter {}'.format(iter_id)) logger.info('Infer iter {}'.format(iter_id))
bbox_results = None bbox_results = None
mask_results = None mask_results = None
if 'bbox' in res: if 'bbox' in res:
bbox_results = bbox2out([res], clsid2catid, is_bbox_normalized) bbox_results = bbox2out([res], clsid2catid, is_bbox_normalized)
if 'mask' in res: if 'mask' in res:
mask_results = mask2out([res], clsid2catid, mask_results = mask2out([res], clsid2catid,
model.mask_head.resolution) model.mask_head.resolution)
# visualize result # visualize result
im_ids = res['im_id'][0] im_ids = res['im_id'][0]
for im_id in im_ids: for im_id in im_ids:
image_path = imid2path[int(im_id)] image_path = imid2path[int(im_id)]
image = Image.open(image_path).convert('RGB') image = Image.open(image_path).convert('RGB')
image = visualize_results(image, image = visualize_results(image,
int(im_id), catid2name, int(im_id), catid2name,
FLAGS.draw_threshold, bbox_results, FLAGS.draw_threshold, bbox_results,
mask_results, is_bbox_normalized) mask_results, is_bbox_normalized)
save_name = get_save_image_name(FLAGS.output_dir, image_path) save_name = get_save_image_name(FLAGS.output_dir, image_path)
logger.info("Detection bbox results save in {}".format(save_name)) logger.info("Detection bbox results save in {}".format(save_name))
image.save(save_name, quality=95) image.save(save_name, quality=95)
if __name__ == '__main__': if __name__ == '__main__':
parser = ArgsParser() parser = ArgsParser()
parser.add_argument( parser.add_argument(
"--infer_dir", "--infer_dir",
type=str, type=str,
default=None, default=None,
help="Directory for images to perform inference on.") help="Directory for images to perform inference on.")
parser.add_argument( parser.add_argument(
"--infer_img", "--infer_img",
type=str, type=str,
default=None, default=None,
help="Image path, has higher priority over --infer_dir") help="Image path, has higher priority over --infer_dir")
parser.add_argument( parser.add_argument(
"--output_dir", "--output_dir",
type=str, type=str,
default="output", default="output",
help="Directory for storing the output visualization files.") help="Directory for storing the output visualization files.")
parser.add_argument( parser.add_argument(
"--draw_threshold", "--draw_threshold",
type=float, type=float,
default=0.5, default=0.5,
help="Threshold to reserve the result for visualization.") help="Threshold to reserve the result for visualization.")
parser.add_argument( parser.add_argument(
"--save_inference_model", "--save_inference_model",
action='store_true', action='store_true',
default=False, default=False,
help="Save inference model in output_dir if True.") help="Save inference model in output_dir if True.")
FLAGS = parser.parse_args() FLAGS = parser.parse_args()
main() main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册