提交 9be01b91 编写于 作者: E edencfc

fix the import path

上级 a488cf69
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import re
import sys
from argparse import ArgumentParser, RawDescriptionHelpFormatter
import yaml
import sys
sys.path.append('..')
from ppdet.core.workspace import get_registered_modules, load_config
from ppdet.utils.cli import ColorTTY
color_tty = ColorTTY()
MISC_CONFIG = {
"architecture": "<value>",
"max_iters": "<value>",
"train_feed": "<value>",
"eval_feed": "<value>",
"test_feed": "<value>",
"pretrain_weights": "<value>",
"save_dir": "<value>",
"weights": "<value>",
"metric": "<value>",
"log_smooth_window": 20,
"snapshot_iter": 10000,
"use_gpu": True,
}
def dump_value(value):
# XXX this is hackish, but collections.abc is not available in python 2
if hasattr(value, '__dict__') or isinstance(value, (dict, tuple, list)):
value = yaml.dump(value, default_flow_style=True)
value = value.replace('\n', '')
value = value.replace('...', '')
return "'{}'".format(value)
else:
# primitive types
return str(value)
def dump_config(module, minimal=False):
args = module.schema.values()
if minimal:
args = [arg for arg in args if not arg.has_default()]
return yaml.dump(
{
module.name: {
arg.name: arg.default if arg.has_default() else "<value>"
for arg in args
}
},
default_flow_style=False,
default_style='')
def list_modules(**kwargs):
target_category = kwargs['category']
module_schema = get_registered_modules()
module_by_category = {}
for schema in module_schema.values():
category = schema.category
if target_category is not None and schema.category != target_category:
continue
if category not in module_by_category:
module_by_category[category] = [schema]
else:
module_by_category[category].append(schema)
for cat, modules in module_by_category.items():
print("Available modules in the category '{}':".format(cat))
print("")
max_len = max([len(mod.name) for mod in modules])
for mod in modules:
print(color_tty.green(mod.name.ljust(max_len)),
mod.doc.split('\n')[0])
print("")
def help_module(**kwargs):
schema = get_registered_modules()[kwargs['module']]
doc = schema.doc is None and "Not documented" or "{}".format(schema.doc)
func_args = {arg.name: arg.doc for arg in schema.schema.values()}
max_len = max([len(k) for k in func_args.keys()])
opts = "\n".join([
"{} {}".format(color_tty.green(k.ljust(max_len)), v)
for k, v in func_args.items()
])
template = dump_config(schema)
print("{}\n\n{}\n\n{}\n\n{}\n\n{}\n\n{}\n{}\n".format(
color_tty.bold(color_tty.blue("MODULE DESCRIPTION:")),
doc,
color_tty.bold(color_tty.blue("MODULE OPTIONS:")),
opts,
color_tty.bold(color_tty.blue("CONFIGURATION TEMPLATE:")),
template,
color_tty.bold(color_tty.blue("COMMAND LINE OPTIONS:")), ))
for arg in schema.schema.values():
print("--opt {}.{}={}".format(schema.name, arg.name,
dump_value(arg.default)
if arg.has_default() else "<value>"))
def generate_config(**kwargs):
minimal = kwargs['minimal']
modules = kwargs['modules']
module_schema = get_registered_modules()
visited = []
schema = []
def walk(m):
if m in visited:
return
s = module_schema[m]
schema.append(s)
visited.append(m)
for mod in modules:
walk(mod)
# XXX try to be smart about when to add header,
# if any "architecture" module, is included, head will be added as well
if any([getattr(m, 'category', None) == 'architecture' for m in schema]):
# XXX for ordered printing
header = ""
for k, v in MISC_CONFIG.items():
header += yaml.dump(
{
k: v
}, default_flow_style=False, default_style='')
print(header)
for s in schema:
print(dump_config(s, minimal))
# FIXME this is pretty hackish, maybe implement a custom YAML printer?
def analyze_config(**kwargs):
config = load_config(kwargs['file'])
modules = get_registered_modules()
green = '___{}___'.format(color_tty.colors.index('green') + 31)
styled = {}
for key in config.keys():
if not config[key]: # empty schema
continue
if key not in modules and not hasattr(config[key], '__dict__'):
styled[key] = config[key]
continue
elif key in modules:
module = modules[key]
else:
type_name = type(config[key]).__name__
if type_name in modules:
module = modules[type_name].copy()
module.update({
k: v
for k, v in config[key].__dict__.items()
if k in module.schema
})
key += " ({})".format(type_name)
default = module.find_default_keys()
missing = module.find_missing_keys()
mismatch = module.find_mismatch_keys()
extra = module.find_extra_keys()
dep_missing = []
for dep in module.inject:
if isinstance(module[dep], str) and module[dep] != '<value>':
if module[dep] not in modules: # not a valid module
dep_missing.append(dep)
else:
dep_mod = modules[module[dep]]
# empty dict but mandatory
if not dep_mod and dep_mod.mandatory():
dep_missing.append(dep)
override = list(
set(module.keys()) - set(default) - set(extra) - set(dep_missing))
replacement = {}
for name in set(override + default + extra + mismatch + missing):
new_name = name
if name in missing:
value = "<missing>"
else:
value = module[name]
if name in extra:
value = dump_value(value) + " <extraneous>"
elif name in mismatch:
value = dump_value(value) + " <type mismatch>"
elif name in dep_missing:
value = dump_value(value) + " <module config missing>"
elif name in override and value != '<missing>':
mark = green
new_name = mark + name
replacement[new_name] = value
styled[key] = replacement
buffer = yaml.dump(styled, default_flow_style=False, default_style='')
buffer = (re.sub(r"<missing>", r"<missing>", buffer))
buffer = (re.sub(r"<extraneous>", r"<extraneous>", buffer))
buffer = (re.sub(r"<type mismatch>", r"<type mismatch>", buffer))
buffer = (re.sub(r"<module config missing>",
r"<module config missing>", buffer))
buffer = re.sub(r"___(\d+)___(.*?):", r"[\1m\2:", buffer)
print(buffer)
if __name__ == '__main__':
argv = sys.argv[1:]
parser = ArgumentParser(formatter_class=RawDescriptionHelpFormatter)
subparsers = parser.add_subparsers(help='Supported Commands')
list_parser = subparsers.add_parser("list", help="list available modules")
help_parser = subparsers.add_parser(
"help", help="show detail options for module")
generate_parser = subparsers.add_parser(
"generate", help="generate configuration template")
analyze_parser = subparsers.add_parser(
"analyze", help="analyze configuration file")
list_parser.set_defaults(func=list_modules)
help_parser.set_defaults(func=help_module)
generate_parser.set_defaults(func=generate_config)
analyze_parser.set_defaults(func=analyze_config)
list_group = list_parser.add_mutually_exclusive_group()
list_group.add_argument(
"-c",
"--category",
type=str,
default=None,
help="list modules for <category>")
help_parser.add_argument(
"module",
help="module to show info for",
choices=list(get_registered_modules().keys()))
generate_parser.add_argument(
"modules",
nargs='+',
help="include these module in generated configuration template",
choices=list(get_registered_modules().keys()))
generate_group = generate_parser.add_mutually_exclusive_group()
generate_group.add_argument(
"--minimal", action='store_true', help="only include required options")
generate_group.add_argument(
"--full",
action='store_false',
dest='minimal',
help="include all options")
analyze_parser.add_argument("file", help="configuration file to analyze")
if len(sys.argv) < 2:
parser.print_help()
sys.exit(1)
args = parser.parse_args(argv)
if hasattr(args, 'func'):
args.func(**vars(args))
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import re
import sys
from argparse import ArgumentParser, RawDescriptionHelpFormatter
import yaml
import sys
sys.path.append('..')
from ppdet.core.workspace import get_registered_modules, load_config
from ppdet.utils.cli import ColorTTY
color_tty = ColorTTY()
MISC_CONFIG = {
"architecture": "<value>",
"max_iters": "<value>",
"train_feed": "<value>",
"eval_feed": "<value>",
"test_feed": "<value>",
"pretrain_weights": "<value>",
"save_dir": "<value>",
"weights": "<value>",
"metric": "<value>",
"log_smooth_window": 20,
"snapshot_iter": 10000,
"use_gpu": True,
}
def dump_value(value):
# XXX this is hackish, but collections.abc is not available in python 2
if hasattr(value, '__dict__') or isinstance(value, (dict, tuple, list)):
value = yaml.dump(value, default_flow_style=True)
value = value.replace('\n', '')
value = value.replace('...', '')
return "'{}'".format(value)
else:
# primitive types
return str(value)
def dump_config(module, minimal=False):
args = module.schema.values()
if minimal:
args = [arg for arg in args if not arg.has_default()]
return yaml.dump(
{
module.name: {
arg.name: arg.default if arg.has_default() else "<value>"
for arg in args
}
},
default_flow_style=False,
default_style='')
def list_modules(**kwargs):
target_category = kwargs['category']
module_schema = get_registered_modules()
module_by_category = {}
for schema in module_schema.values():
category = schema.category
if target_category is not None and schema.category != target_category:
continue
if category not in module_by_category:
module_by_category[category] = [schema]
else:
module_by_category[category].append(schema)
for cat, modules in module_by_category.items():
print("Available modules in the category '{}':".format(cat))
print("")
max_len = max([len(mod.name) for mod in modules])
for mod in modules:
print(color_tty.green(mod.name.ljust(max_len)),
mod.doc.split('\n')[0])
print("")
def help_module(**kwargs):
schema = get_registered_modules()[kwargs['module']]
doc = schema.doc is None and "Not documented" or "{}".format(schema.doc)
func_args = {arg.name: arg.doc for arg in schema.schema.values()}
max_len = max([len(k) for k in func_args.keys()])
opts = "\n".join([
"{} {}".format(color_tty.green(k.ljust(max_len)), v)
for k, v in func_args.items()
])
template = dump_config(schema)
print("{}\n\n{}\n\n{}\n\n{}\n\n{}\n\n{}\n{}\n".format(
color_tty.bold(color_tty.blue("MODULE DESCRIPTION:")),
doc,
color_tty.bold(color_tty.blue("MODULE OPTIONS:")),
opts,
color_tty.bold(color_tty.blue("CONFIGURATION TEMPLATE:")),
template,
color_tty.bold(color_tty.blue("COMMAND LINE OPTIONS:")), ))
for arg in schema.schema.values():
print("--opt {}.{}={}".format(schema.name, arg.name,
dump_value(arg.default)
if arg.has_default() else "<value>"))
def generate_config(**kwargs):
minimal = kwargs['minimal']
modules = kwargs['modules']
module_schema = get_registered_modules()
visited = []
schema = []
def walk(m):
if m in visited:
return
s = module_schema[m]
schema.append(s)
visited.append(m)
for mod in modules:
walk(mod)
# XXX try to be smart about when to add header,
# if any "architecture" module, is included, head will be added as well
if any([getattr(m, 'category', None) == 'architecture' for m in schema]):
# XXX for ordered printing
header = ""
for k, v in MISC_CONFIG.items():
header += yaml.dump(
{
k: v
}, default_flow_style=False, default_style='')
print(header)
for s in schema:
print(dump_config(s, minimal))
# FIXME this is pretty hackish, maybe implement a custom YAML printer?
def analyze_config(**kwargs):
config = load_config(kwargs['file'])
modules = get_registered_modules()
green = '___{}___'.format(color_tty.colors.index('green') + 31)
styled = {}
for key in config.keys():
if not config[key]: # empty schema
continue
if key not in modules and not hasattr(config[key], '__dict__'):
styled[key] = config[key]
continue
elif key in modules:
module = modules[key]
else:
type_name = type(config[key]).__name__
if type_name in modules:
module = modules[type_name].copy()
module.update({
k: v
for k, v in config[key].__dict__.items()
if k in module.schema
})
key += " ({})".format(type_name)
default = module.find_default_keys()
missing = module.find_missing_keys()
mismatch = module.find_mismatch_keys()
extra = module.find_extra_keys()
dep_missing = []
for dep in module.inject:
if isinstance(module[dep], str) and module[dep] != '<value>':
if module[dep] not in modules: # not a valid module
dep_missing.append(dep)
else:
dep_mod = modules[module[dep]]
# empty dict but mandatory
if not dep_mod and dep_mod.mandatory():
dep_missing.append(dep)
override = list(
set(module.keys()) - set(default) - set(extra) - set(dep_missing))
replacement = {}
for name in set(override + default + extra + mismatch + missing):
new_name = name
if name in missing:
value = "<missing>"
else:
value = module[name]
if name in extra:
value = dump_value(value) + " <extraneous>"
elif name in mismatch:
value = dump_value(value) + " <type mismatch>"
elif name in dep_missing:
value = dump_value(value) + " <module config missing>"
elif name in override and value != '<missing>':
mark = green
new_name = mark + name
replacement[new_name] = value
styled[key] = replacement
buffer = yaml.dump(styled, default_flow_style=False, default_style='')
buffer = (re.sub(r"<missing>", r"<missing>", buffer))
buffer = (re.sub(r"<extraneous>", r"<extraneous>", buffer))
buffer = (re.sub(r"<type mismatch>", r"<type mismatch>", buffer))
buffer = (re.sub(r"<module config missing>",
r"<module config missing>", buffer))
buffer = re.sub(r"___(\d+)___(.*?):", r"[\1m\2:", buffer)
print(buffer)
if __name__ == '__main__':
argv = sys.argv[1:]
parser = ArgumentParser(formatter_class=RawDescriptionHelpFormatter)
subparsers = parser.add_subparsers(help='Supported Commands')
list_parser = subparsers.add_parser("list", help="list available modules")
help_parser = subparsers.add_parser(
"help", help="show detail options for module")
generate_parser = subparsers.add_parser(
"generate", help="generate configuration template")
analyze_parser = subparsers.add_parser(
"analyze", help="analyze configuration file")
list_parser.set_defaults(func=list_modules)
help_parser.set_defaults(func=help_module)
generate_parser.set_defaults(func=generate_config)
analyze_parser.set_defaults(func=analyze_config)
list_group = list_parser.add_mutually_exclusive_group()
list_group.add_argument(
"-c",
"--category",
type=str,
default=None,
help="list modules for <category>")
help_parser.add_argument(
"module",
help="module to show info for",
choices=list(get_registered_modules().keys()))
generate_parser.add_argument(
"modules",
nargs='+',
help="include these module in generated configuration template",
choices=list(get_registered_modules().keys()))
generate_group = generate_parser.add_mutually_exclusive_group()
generate_group.add_argument(
"--minimal", action='store_true', help="only include required options")
generate_group.add_argument(
"--full",
action='store_false',
dest='minimal',
help="include all options")
analyze_parser.add_argument("file", help="configuration file to analyze")
if len(sys.argv) < 2:
parser.print_help()
sys.exit(1)
args = parser.parse_args(argv)
if hasattr(args, 'func'):
args.func(**vars(args))
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import multiprocessing
import paddle.fluid as fluid
import sys
sys.path.append('..')
from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results
import ppdet.utils.checkpoint as checkpoint
from ppdet.utils.cli import ArgsParser
from ppdet.utils.check import check_gpu
from ppdet.modeling.model_input import create_feed
from ppdet.data.data_feed import create_reader
from ppdet.core.workspace import load_config, merge_config, create
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
def main():
"""
Main evaluate function
"""
cfg = load_config(FLAGS.config)
if 'architecture' in cfg:
main_arch = cfg.architecture
else:
raise ValueError("'architecture' not specified in config file.")
merge_config(FLAGS.opt)
# check if set use_gpu=True in paddlepaddle cpu version
check_gpu(cfg.use_gpu)
if cfg.use_gpu:
devices_num = fluid.core.get_cuda_device_count()
else:
devices_num = int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
if 'eval_feed' not in cfg:
eval_feed = create(main_arch + 'EvalFeed')
else:
eval_feed = create(cfg.eval_feed)
# define executor
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# build program
model = create(main_arch)
startup_prog = fluid.Program()
eval_prog = fluid.Program()
with fluid.program_guard(eval_prog, startup_prog):
with fluid.unique_name.guard():
pyreader, feed_vars = create_feed(eval_feed)
fetches = model.eval(feed_vars)
eval_prog = eval_prog.clone(True)
reader = create_reader(eval_feed)
pyreader.decorate_sample_list_generator(reader, place)
# compile program for multi-devices
if devices_num <= 1:
compile_program = fluid.compiler.CompiledProgram(eval_prog)
else:
build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = False
build_strategy.enable_inplace = False
compile_program = fluid.compiler.CompiledProgram(
eval_prog).with_data_parallel(build_strategy=build_strategy)
# load model
exe.run(startup_prog)
if 'weights' in cfg:
checkpoint.load_pretrain(exe, eval_prog, cfg.weights)
extra_keys = []
if 'metric' in cfg and cfg.metric == 'COCO':
extra_keys = ['im_info', 'im_id', 'im_shape']
keys, values, cls = parse_fetches(fetches, eval_prog, extra_keys)
results = eval_run(exe, compile_program, pyreader, keys, values, cls)
# evaluation
resolution = None
if 'mask' in results[0]:
resolution = model.mask_head.resolution
eval_results(results, eval_feed, cfg.metric, resolution, FLAGS.output_file)
if __name__ == '__main__':
parser = ArgsParser()
parser.add_argument(
"-f",
"--output_file",
default=None,
type=str,
help="Evaluation file name, default to bbox.json and mask.json.")
FLAGS = parser.parse_args()
main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import multiprocessing
import paddle.fluid as fluid
import sys
sys.path.append('..')
from ppdet.utils.eval_utils import parse_fetches, eval_run, eval_results
import ppdet.utils.checkpoint as checkpoint
from ppdet.utils.cli import ArgsParser
from ppdet.utils.check import check_gpu
from ppdet.modeling.model_input import create_feed
from ppdet.data.data_feed import create_reader
from ppdet.core.workspace import load_config, merge_config, create
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
def main():
"""
Main evaluate function
"""
cfg = load_config(FLAGS.config)
if 'architecture' in cfg:
main_arch = cfg.architecture
else:
raise ValueError("'architecture' not specified in config file.")
merge_config(FLAGS.opt)
# check if set use_gpu=True in paddlepaddle cpu version
check_gpu(cfg.use_gpu)
if cfg.use_gpu:
devices_num = fluid.core.get_cuda_device_count()
else:
devices_num = int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
if 'eval_feed' not in cfg:
eval_feed = create(main_arch + 'EvalFeed')
else:
eval_feed = create(cfg.eval_feed)
# define executor
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
# build program
model = create(main_arch)
startup_prog = fluid.Program()
eval_prog = fluid.Program()
with fluid.program_guard(eval_prog, startup_prog):
with fluid.unique_name.guard():
pyreader, feed_vars = create_feed(eval_feed)
fetches = model.eval(feed_vars)
eval_prog = eval_prog.clone(True)
reader = create_reader(eval_feed)
pyreader.decorate_sample_list_generator(reader, place)
# compile program for multi-devices
if devices_num <= 1:
compile_program = fluid.compiler.CompiledProgram(eval_prog)
else:
build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = False
build_strategy.enable_inplace = False
compile_program = fluid.compiler.CompiledProgram(
eval_prog).with_data_parallel(build_strategy=build_strategy)
# load model
exe.run(startup_prog)
if 'weights' in cfg:
checkpoint.load_pretrain(exe, eval_prog, cfg.weights)
extra_keys = []
if 'metric' in cfg and cfg.metric == 'COCO':
extra_keys = ['im_info', 'im_id', 'im_shape']
keys, values, cls = parse_fetches(fetches, eval_prog, extra_keys)
results = eval_run(exe, compile_program, pyreader, keys, values, cls)
# evaluation
resolution = None
if 'mask' in results[0]:
resolution = model.mask_head.resolution
eval_results(results, eval_feed, cfg.metric, resolution, FLAGS.output_file)
if __name__ == '__main__':
parser = ArgsParser()
parser.add_argument(
"-f",
"--output_file",
default=None,
type=str,
help="Evaluation file name, default to bbox.json and mask.json.")
FLAGS = parser.parse_args()
main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import glob
import numpy as np
from PIL import Image
from paddle import fluid
import sys
sys.path.append('..')
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.modeling.model_input import create_feed
from ppdet.data.data_feed import create_reader
from ppdet.utils.eval_utils import parse_fetches
from ppdet.utils.cli import ArgsParser
from ppdet.utils.check import check_gpu
from ppdet.utils.visualizer import visualize_results
import ppdet.utils.checkpoint as checkpoint
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
def get_save_image_name(output_dir, image_path):
"""
Get save image name from source image path.
"""
if not os.path.exists(output_dir):
os.makedirs(output_dir)
image_name = image_path.split('/')[-1]
name, ext = os.path.splitext(image_name)
return os.path.join(output_dir, "{}".format(name)) + ext
def get_test_images(infer_dir, infer_img):
"""
Get image path list in TEST mode
"""
assert infer_img is not None or infer_dir is not None, \
"--infer_img or --infer_dir should be set"
assert infer_img is None or os.path.isfile(infer_img), \
"{} is not a file".format(infer_img)
assert infer_dir is None or os.path.isdir(infer_dir), \
"{} is not a directory".format(infer_dir)
images = []
# infer_img has a higher priority
if infer_img and os.path.isfile(infer_img):
images.append(infer_img)
return images
infer_dir = os.path.abspath(infer_dir)
assert os.path.isdir(infer_dir), \
"infer_dir {} is not a directory".format(infer_dir)
exts = ['jpg', 'jpeg', 'png', 'bmp']
exts += [ext.upper() for ext in exts]
for ext in exts:
images.extend(glob.glob('{}/*.{}'.format(infer_dir, ext)))
assert len(images) > 0, "no image found in {}".format(infer_dir)
logger.info("Found {} inference images in total.".format(len(images)))
return images
def prune_feed_vars(feeded_var_names, target_vars, prog):
"""
Filter out feed variables which are not in program,
pruned feed variables are only used in post processing
on model output, which are not used in program, such
as im_id to identify image order, im_shape to clip bbox
in image.
"""
exist_var_names = []
prog = prog.clone()
prog = prog._prune(targets=target_vars)
global_block = prog.global_block()
for name in feeded_var_names:
try:
v = global_block.var(name)
exist_var_names.append(v.name)
except Exception:
logger.info('save_inference_model pruned unused feed '
'variables {}'.format(name))
pass
return exist_var_names
def save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog):
cfg_name = os.path.basename(FLAGS.config).split('.')[0]
save_dir = os.path.join(FLAGS.output_dir, cfg_name)
feeded_var_names = [var.name for var in feed_vars.values()]
target_vars = test_fetches.values()
feeded_var_names = prune_feed_vars(feeded_var_names, target_vars, infer_prog)
logger.info("Save inference model to {}, input: {}, output: "
"{}...".format(save_dir, feeded_var_names,
[var.name for var in target_vars]))
fluid.io.save_inference_model(save_dir,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
executor=exe,
main_program=infer_prog,
params_filename="__params__")
def main():
cfg = load_config(FLAGS.config)
if 'architecture' in cfg:
main_arch = cfg.architecture
else:
raise ValueError("'architecture' not specified in config file.")
merge_config(FLAGS.opt)
# check if set use_gpu=True in paddlepaddle cpu version
check_gpu(cfg.use_gpu)
if 'test_feed' not in cfg:
test_feed = create(main_arch + 'TestFeed')
else:
test_feed = create(cfg.test_feed)
test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)
test_feed.dataset.add_images(test_images)
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
model = create(main_arch)
startup_prog = fluid.Program()
infer_prog = fluid.Program()
with fluid.program_guard(infer_prog, startup_prog):
with fluid.unique_name.guard():
_, feed_vars = create_feed(test_feed, use_pyreader=False)
test_fetches = model.test(feed_vars)
infer_prog = infer_prog.clone(True)
reader = create_reader(test_feed)
feeder = fluid.DataFeeder(place=place, feed_list=feed_vars.values())
exe.run(startup_prog)
if cfg.weights:
checkpoint.load_checkpoint(exe, infer_prog, cfg.weights)
if FLAGS.save_inference_model:
save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog)
# parse infer fetches
extra_keys = []
if cfg['metric'] == 'COCO':
extra_keys = ['im_info', 'im_id', 'im_shape']
if cfg['metric'] == 'VOC':
extra_keys = ['im_id']
keys, values, _ = parse_fetches(test_fetches, infer_prog, extra_keys)
# parse dataset category
if cfg.metric == 'COCO':
from ppdet.utils.coco_eval import bbox2out, mask2out, get_category_info
if cfg.metric == "VOC":
from ppdet.utils.voc_eval import bbox2out, get_category_info
anno_file = getattr(test_feed.dataset, 'annotation', None)
with_background = getattr(test_feed, 'with_background', True)
use_default_label = getattr(test_feed, 'use_default_label', False)
clsid2catid, catid2name = get_category_info(anno_file, with_background,
use_default_label)
# whether output bbox is normalized in model output layer
is_bbox_normalized = False
if hasattr(model, 'is_bbox_normalized') and \
callable(model.is_bbox_normalized):
is_bbox_normalized = model.is_bbox_normalized()
imid2path = reader.imid2path
for iter_id, data in enumerate(reader()):
outs = exe.run(infer_prog,
feed=feeder.feed(data),
fetch_list=values,
return_numpy=False)
res = {
k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(keys, outs)
}
logger.info('Infer iter {}'.format(iter_id))
bbox_results = None
mask_results = None
if 'bbox' in res:
bbox_results = bbox2out([res], clsid2catid, is_bbox_normalized)
if 'mask' in res:
mask_results = mask2out([res], clsid2catid,
model.mask_head.resolution)
# visualize result
im_ids = res['im_id'][0]
for im_id in im_ids:
image_path = imid2path[int(im_id)]
image = Image.open(image_path).convert('RGB')
image = visualize_results(image,
int(im_id), catid2name,
FLAGS.draw_threshold, bbox_results,
mask_results, is_bbox_normalized)
save_name = get_save_image_name(FLAGS.output_dir, image_path)
logger.info("Detection bbox results save in {}".format(save_name))
image.save(save_name, quality=95)
if __name__ == '__main__':
parser = ArgsParser()
parser.add_argument(
"--infer_dir",
type=str,
default=None,
help="Directory for images to perform inference on.")
parser.add_argument(
"--infer_img",
type=str,
default=None,
help="Image path, has higher priority over --infer_dir")
parser.add_argument(
"--output_dir",
type=str,
default="output",
help="Directory for storing the output visualization files.")
parser.add_argument(
"--draw_threshold",
type=float,
default=0.5,
help="Threshold to reserve the result for visualization.")
parser.add_argument(
"--save_inference_model",
action='store_true',
default=False,
help="Save inference model in output_dir if True.")
FLAGS = parser.parse_args()
main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import glob
import numpy as np
from PIL import Image
from paddle import fluid
import sys
sys.path.append('..')
from ppdet.core.workspace import load_config, merge_config, create
from ppdet.modeling.model_input import create_feed
from ppdet.data.data_feed import create_reader
from ppdet.utils.eval_utils import parse_fetches
from ppdet.utils.cli import ArgsParser
from ppdet.utils.check import check_gpu
from ppdet.utils.visualizer import visualize_results
import ppdet.utils.checkpoint as checkpoint
import logging
FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logger = logging.getLogger(__name__)
def get_save_image_name(output_dir, image_path):
"""
Get save image name from source image path.
"""
if not os.path.exists(output_dir):
os.makedirs(output_dir)
image_name = image_path.split('/')[-1]
name, ext = os.path.splitext(image_name)
return os.path.join(output_dir, "{}".format(name)) + ext
def get_test_images(infer_dir, infer_img):
"""
Get image path list in TEST mode
"""
assert infer_img is not None or infer_dir is not None, \
"--infer_img or --infer_dir should be set"
assert infer_img is None or os.path.isfile(infer_img), \
"{} is not a file".format(infer_img)
assert infer_dir is None or os.path.isdir(infer_dir), \
"{} is not a directory".format(infer_dir)
images = []
# infer_img has a higher priority
if infer_img and os.path.isfile(infer_img):
images.append(infer_img)
return images
infer_dir = os.path.abspath(infer_dir)
assert os.path.isdir(infer_dir), \
"infer_dir {} is not a directory".format(infer_dir)
exts = ['jpg', 'jpeg', 'png', 'bmp']
exts += [ext.upper() for ext in exts]
for ext in exts:
images.extend(glob.glob('{}/*.{}'.format(infer_dir, ext)))
assert len(images) > 0, "no image found in {}".format(infer_dir)
logger.info("Found {} inference images in total.".format(len(images)))
return images
def prune_feed_vars(feeded_var_names, target_vars, prog):
"""
Filter out feed variables which are not in program,
pruned feed variables are only used in post processing
on model output, which are not used in program, such
as im_id to identify image order, im_shape to clip bbox
in image.
"""
exist_var_names = []
prog = prog.clone()
prog = prog._prune(targets=target_vars)
global_block = prog.global_block()
for name in feeded_var_names:
try:
v = global_block.var(name)
exist_var_names.append(v.name)
except Exception:
logger.info('save_inference_model pruned unused feed '
'variables {}'.format(name))
pass
return exist_var_names
def save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog):
cfg_name = os.path.basename(FLAGS.config).split('.')[0]
save_dir = os.path.join(FLAGS.output_dir, cfg_name)
feeded_var_names = [var.name for var in feed_vars.values()]
target_vars = test_fetches.values()
feeded_var_names = prune_feed_vars(feeded_var_names, target_vars, infer_prog)
logger.info("Save inference model to {}, input: {}, output: "
"{}...".format(save_dir, feeded_var_names,
[var.name for var in target_vars]))
fluid.io.save_inference_model(save_dir,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
executor=exe,
main_program=infer_prog,
params_filename="__params__")
def main():
cfg = load_config(FLAGS.config)
if 'architecture' in cfg:
main_arch = cfg.architecture
else:
raise ValueError("'architecture' not specified in config file.")
merge_config(FLAGS.opt)
# check if set use_gpu=True in paddlepaddle cpu version
check_gpu(cfg.use_gpu)
if 'test_feed' not in cfg:
test_feed = create(main_arch + 'TestFeed')
else:
test_feed = create(cfg.test_feed)
test_images = get_test_images(FLAGS.infer_dir, FLAGS.infer_img)
test_feed.dataset.add_images(test_images)
place = fluid.CUDAPlace(0) if cfg.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
model = create(main_arch)
startup_prog = fluid.Program()
infer_prog = fluid.Program()
with fluid.program_guard(infer_prog, startup_prog):
with fluid.unique_name.guard():
_, feed_vars = create_feed(test_feed, use_pyreader=False)
test_fetches = model.test(feed_vars)
infer_prog = infer_prog.clone(True)
reader = create_reader(test_feed)
feeder = fluid.DataFeeder(place=place, feed_list=feed_vars.values())
exe.run(startup_prog)
if cfg.weights:
checkpoint.load_checkpoint(exe, infer_prog, cfg.weights)
if FLAGS.save_inference_model:
save_infer_model(FLAGS, exe, feed_vars, test_fetches, infer_prog)
# parse infer fetches
extra_keys = []
if cfg['metric'] == 'COCO':
extra_keys = ['im_info', 'im_id', 'im_shape']
if cfg['metric'] == 'VOC':
extra_keys = ['im_id']
keys, values, _ = parse_fetches(test_fetches, infer_prog, extra_keys)
# parse dataset category
if cfg.metric == 'COCO':
from ppdet.utils.coco_eval import bbox2out, mask2out, get_category_info
if cfg.metric == "VOC":
from ppdet.utils.voc_eval import bbox2out, get_category_info
anno_file = getattr(test_feed.dataset, 'annotation', None)
with_background = getattr(test_feed, 'with_background', True)
use_default_label = getattr(test_feed, 'use_default_label', False)
clsid2catid, catid2name = get_category_info(anno_file, with_background,
use_default_label)
# whether output bbox is normalized in model output layer
is_bbox_normalized = False
if hasattr(model, 'is_bbox_normalized') and \
callable(model.is_bbox_normalized):
is_bbox_normalized = model.is_bbox_normalized()
imid2path = reader.imid2path
for iter_id, data in enumerate(reader()):
outs = exe.run(infer_prog,
feed=feeder.feed(data),
fetch_list=values,
return_numpy=False)
res = {
k: (np.array(v), v.recursive_sequence_lengths())
for k, v in zip(keys, outs)
}
logger.info('Infer iter {}'.format(iter_id))
bbox_results = None
mask_results = None
if 'bbox' in res:
bbox_results = bbox2out([res], clsid2catid, is_bbox_normalized)
if 'mask' in res:
mask_results = mask2out([res], clsid2catid,
model.mask_head.resolution)
# visualize result
im_ids = res['im_id'][0]
for im_id in im_ids:
image_path = imid2path[int(im_id)]
image = Image.open(image_path).convert('RGB')
image = visualize_results(image,
int(im_id), catid2name,
FLAGS.draw_threshold, bbox_results,
mask_results, is_bbox_normalized)
save_name = get_save_image_name(FLAGS.output_dir, image_path)
logger.info("Detection bbox results save in {}".format(save_name))
image.save(save_name, quality=95)
if __name__ == '__main__':
parser = ArgsParser()
parser.add_argument(
"--infer_dir",
type=str,
default=None,
help="Directory for images to perform inference on.")
parser.add_argument(
"--infer_img",
type=str,
default=None,
help="Image path, has higher priority over --infer_dir")
parser.add_argument(
"--output_dir",
type=str,
default="output",
help="Directory for storing the output visualization files.")
parser.add_argument(
"--draw_threshold",
type=float,
default=0.5,
help="Threshold to reserve the result for visualization.")
parser.add_argument(
"--save_inference_model",
action='store_true',
default=False,
help="Save inference model in output_dir if True.")
FLAGS = parser.parse_args()
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册