未验证 提交 b2d0da6a 编写于 作者: R ruri 提交者: GitHub

refine infer in image classification (#4077)

上级 4ffbe264
...@@ -34,7 +34,7 @@ parser = argparse.ArgumentParser(description=__doc__) ...@@ -34,7 +34,7 @@ parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser) add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable # yapf: disable
add_arg('data_dir', str, "./data/ILSVRC2012/", "The ImageNet datset") add_arg('data_dir', str, "./data/ILSVRC2012/", "The ImageNet datset")
add_arg('batch_size', int, 256, "batch size on the all devices.") add_arg('batch_size', int, 256, "batch size on all the devices.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.") add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('class_dim', int, 1000, "Class number.") add_arg('class_dim', int, 1000, "Class number.")
parser.add_argument("--pretrained_model", default=None, required=True, type=str, help="The path to load pretrained model") parser.add_argument("--pretrained_model", default=None, required=True, type=str, help="The path to load pretrained model")
......
...@@ -23,6 +23,7 @@ import math ...@@ -23,6 +23,7 @@ import math
import numpy as np import numpy as np
import argparse import argparse
import functools import functools
import re
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
...@@ -51,7 +52,7 @@ add_arg('interpolation', int, None, "The interpolation mode" ...@@ -51,7 +52,7 @@ add_arg('interpolation', int, None, "The interpolation mode"
add_arg('padding_type', str, "SAME", "Padding type of convolution") add_arg('padding_type', str, "SAME", "Padding type of convolution")
add_arg('use_se', bool, True, "Whether to use Squeeze-and-Excitation module for EfficientNet.") add_arg('use_se', bool, True, "Whether to use Squeeze-and-Excitation module for EfficientNet.")
add_arg('image_path', str, None, "single image path") add_arg('image_path', str, None, "single image path")
add_arg('batch_size', int, 8, "batch_size on all devices") add_arg('batch_size', int, 8, "batch_size on all the devices")
add_arg('save_json_path', str, None, "save output to a json file") add_arg('save_json_path', str, None, "save output to a json file")
# yapf: enable # yapf: enable
...@@ -101,8 +102,9 @@ def infer(args): ...@@ -101,8 +102,9 @@ def infer(args):
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
places = fluid.framework.cuda_places()
compiled_program = fluid.compiler.CompiledProgram( compiled_program = fluid.compiler.CompiledProgram(
test_program).with_data_parallel() test_program).with_data_parallel(places=places)
fluid.io.load_persistables(exe, args.pretrained_model) fluid.io.load_persistables(exe, args.pretrained_model)
if args.save_inference: if args.save_inference:
...@@ -119,42 +121,68 @@ def infer(args): ...@@ -119,42 +121,68 @@ def infer(args):
imagenet_reader = reader.ImageNetReader() imagenet_reader = reader.ImageNetReader()
test_reader = imagenet_reader.test(settings=args) test_reader = imagenet_reader.test(settings=args)
feeder = fluid.DataFeeder(place=places, feed_list=[image])
feeder = fluid.DataFeeder(place=place, feed_list=[image])
test_reader = feeder.decorate_reader(test_reader, multi_devices=True)
TOPK = args.topk TOPK = args.topk
if os.path.exists(args.class_map_path): if os.path.exists(args.class_map_path):
print("The map of readable label and numerical label has been found!") print("The map of readable label and numerical label has been found!")
f = open(args.class_map_path) with open(args.class_map_path) as f:
label_dict = {} label_dict = {}
for item in f.readlines(): strinfo = re.compile(r"\d+ ")
key = item.split(" ")[0] for item in f.readlines():
value = [l.replace("\n", "") for l in item.split(" ")[1:]] key = item.split(" ")[0]
label_dict[key] = value value = [
strinfo.sub("", l).replace("\n", "")
for l in item.split(", ")
]
label_dict[key] = value
info = {}
parallel_data = []
parallel_id = []
place_num = paddle.fluid.core.get_cuda_device_count()
for batch_id, data in enumerate(test_reader()): for batch_id, data in enumerate(test_reader()):
result = exe.run(compiled_program, fetch_list=fetch_list, feed=data) image_data = [[items[0]] for items in data]
result = result[0][0] image_id = [items[1] for items in data]
pred_label = np.argsort(result)[::-1][:TOPK]
parallel_id.append(image_id)
if os.path.exists(args.class_map_path): parallel_data.append(image_data)
readable_pred_label = []
for label in pred_label: if place_num == len(parallel_data):
readable_pred_label.append(label_dict[str(label)]) result = exe.run(
print(readable_pred_label) compiled_program,
info = "Test-{0}-score: {1}, class{2} {3}".format( fetch_list=fetch_list,
batch_id, result[pred_label], pred_label, readable_pred_label) feed=list(feeder.feed_parallel(parallel_data, place_num)))
else: for i, res in enumerate(result[0]):
info = "Test-{0}-score: {1}, class{2}".format( pred_label = np.argsort(res)[::-1][:TOPK]
batch_id, result[pred_label], pred_label) real_id = str(np.array(parallel_id).flatten()[i])
print(info) _, real_id = os.path.split(real_id)
if args.save_json_path:
save_json(info, args.save_json_path) if os.path.exists(args.class_map_path):
readable_pred_label = []
sys.stdout.flush() for label in pred_label:
if args.image_path: readable_pred_label.append(label_dict[str(label)])
os.remove(".tmp.txt")
info[real_id] = {}
info[real_id]['score'], info[real_id]['class'], info[
real_id]['class_name'] = str(res[pred_label]), str(
pred_label), readable_pred_label
else:
info[real_id] = {}
info[real_id]['score'], info[real_id]['class'] = str(res[
pred_label]), str(pred_label)
print(real_id, info[real_id])
sys.stdout.flush()
if args.save_json_path:
save_json(info, args.save_json_path)
parallel_data = []
parallel_id = []
if args.image_path:
os.remove(".tmp.txt")
def main(): def main():
......
...@@ -240,7 +240,7 @@ def process_image(sample, settings, mode, color_jitter, rotate): ...@@ -240,7 +240,7 @@ def process_image(sample, settings, mode, color_jitter, rotate):
if mode == 'train' or mode == 'val': if mode == 'train' or mode == 'val':
return (img, sample[1]) return (img, sample[1])
elif mode == 'test': elif mode == 'test':
return (img, ) return (img, sample[0])
def process_batch_data(input_data, settings, mode, color_jitter, rotate): def process_batch_data(input_data, settings, mode, color_jitter, rotate):
...@@ -262,6 +262,23 @@ class ImageNetReader: ...@@ -262,6 +262,23 @@ class ImageNetReader:
assert isinstance(seed, int), "shuffle seed must be int" assert isinstance(seed, int), "shuffle seed must be int"
self.shuffle_seed = seed self.shuffle_seed = seed
def _get_single_card_bs(self, settings, mode):
    """Return the batch size each device should process for ``mode``.

    The total batch size is ``settings.test_batch_size`` when ``mode`` is
    ``"val"`` and that value is truthy; otherwise it is
    ``settings.batch_size``. The total is floor-divided by the number of
    devices: the CUDA device count reported by Paddle when
    ``settings.use_gpu`` is truthy, else the ``CPU_NUM`` environment
    variable (1 when it is unset).

    Args:
        settings: object exposing ``use_gpu`` and ``batch_size``;
            ``test_batch_size`` is only read when ``mode == "val"``.
        mode: reader mode string; only ``"val"`` is treated specially.

    Returns:
        The per-device batch size (floor division, so it may be 0 when the
        total is smaller than the device count).
    """
    # Test `mode` first so that settings objects without a
    # `test_batch_size` attribute still work for every non-"val" mode.
    if mode == "val" and settings.test_batch_size:
        total_bs = settings.test_batch_size
    else:
        total_bs = settings.batch_size

    if settings.use_gpu:
        device_num = paddle.fluid.core.get_cuda_device_count()
    else:
        device_num = int(os.environ.get('CPU_NUM', 1))

    return total_bs // device_num
def _reader_creator(self, def _reader_creator(self,
settings, settings,
file_list, file_list,
...@@ -272,12 +289,7 @@ class ImageNetReader: ...@@ -272,12 +289,7 @@ class ImageNetReader:
data_dir=None): data_dir=None):
num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) num_trainers = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
if settings.use_gpu: batch_size = self._get_single_card_bs(settings, mode)
batch_size = settings.batch_size // paddle.fluid.core.get_cuda_device_count(
)
else:
batch_size = settings.batch_size // int(
os.environ.get('CPU_NUM', 1))
def reader(): def reader():
def read_file_list(): def read_file_list():
...@@ -304,12 +316,15 @@ class ImageNetReader: ...@@ -304,12 +316,15 @@ class ImageNetReader:
full_lines = [] full_lines = []
for i in range(settings.same_feed): for i in range(settings.same_feed):
full_lines.append(temp_file) full_lines.append(temp_file)
for line in full_lines: for line in full_lines:
img_path, label = line.split() img_path, label = line.split()
img_path = os.path.join(data_dir, img_path) img_path = os.path.join(data_dir, img_path)
batch_data.append([img_path, int(label)]) batch_data.append([img_path, int(label)])
if len(batch_data) == batch_size: if len(batch_data) == batch_size:
if mode == 'train' or mode == 'val' or mode == 'test': if mode == 'train' or mode == 'val' or mode == 'test':
yield batch_data yield batch_data
batch_data = [] batch_data = []
......
...@@ -102,8 +102,8 @@ def parse_args(): ...@@ -102,8 +102,8 @@ def parse_args():
parser.add_argument('--image_shape', nargs='+', type=int, default=[3, 224, 224], help="The shape of image") parser.add_argument('--image_shape', nargs='+', type=int, default=[3, 224, 224], help="The shape of image")
add_arg('num_epochs', int, 120, "The number of total epochs.") add_arg('num_epochs', int, 120, "The number of total epochs.")
add_arg('class_dim', int, 1000, "The number of total classes.") add_arg('class_dim', int, 1000, "The number of total classes.")
add_arg('batch_size', int, 8, "Minibatch size on all devices.") add_arg('batch_size', int, 8, "Minibatch size on all the devices.")
add_arg('test_batch_size', int, 16, "Test batch size on all devices.") add_arg('test_batch_size', int, None, "Test batch size on all the devices.")
add_arg('lr', float, 0.1, "The learning rate.") add_arg('lr', float, 0.1, "The learning rate.")
add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.") add_arg('lr_strategy', str, "piecewise_decay", "The learning rate decay strategy.")
add_arg('l2_decay', float, 1e-4, "The l2_decay parameter.") add_arg('l2_decay', float, 1e-4, "The l2_decay parameter.")
...@@ -287,10 +287,35 @@ def init_model(exe, args, program): ...@@ -287,10 +287,35 @@ def init_model(exe, args, program):
print("Finish initing model from %s" % (args.checkpoint)) print("Finish initing model from %s" % (args.checkpoint))
if args.pretrained_model: if args.pretrained_model:
# yapf: disable
#XXX: should rename all models' final fc layers name as final_fc_weights and final_fc_offset!
final_fc_name = [
"fc8_weights","fc8_offset", #alexnet
"fc_weights","fc_offset", #darknet, densenet, dpn, hrnet, mobilenet_v3, res2net, res2net_vd, resnext, resnext_vd, xception
#efficient
"out","out_offset", "out1","out1_offset", "out2","out2_offset", #googlenet
"final_fc_weights", "final_fc_offset", #inception_v4
"fc7_weights", "fc7_offset", #mobilenetv1
"fc10_weights", "fc10_offset", #mobilenetv2
"fc_0", #resnet, resnet_vc, resnet_vd
"fc.weight", "fc.bias", #resnext101_wsl
"fc6_weights", "fc6_offset", #se_resnet_vd, se_resnext, se_resnext_vd, shufflenet_v2, shufflenet_v2_swish,
#squeezenet
"fc8_weights", "fc8_offset", #vgg
"fc_bias" #"fc_weights", xception_deeplab
]
# yapf: enable
def is_parameter(var): def is_parameter(var):
return isinstance(var, fluid.framework.Parameter) and ( fc_exclude_flag = False
not ("fc_0" in var.name)) and os.path.exists( for item in final_fc_name:
if item in var.name:
fc_exclude_flag = True
return isinstance(
var, fluid.framework.
Parameter) and not fc_exclude_flag and os.path.exists(
os.path.join(args.pretrained_model, var.name)) os.path.join(args.pretrained_model, var.name))
print("Load pretrain weights from {}, exclude fc layer.".format( print("Load pretrain weights from {}, exclude fc layer.".format(
...@@ -314,7 +339,7 @@ def save_model(args, exe, train_prog, info): ...@@ -314,7 +339,7 @@ def save_model(args, exe, train_prog, info):
def save_json(info, path): def save_json(info, path):
""" save eval result or infer result to file as json format. """ save eval result or infer result to file as json format.
""" """
with open(path, 'a') as f: with open(path, 'w') as f:
json.dump(info, f) json.dump(info, f)
...@@ -493,7 +518,7 @@ def best_strategy_compiled(args, ...@@ -493,7 +518,7 @@ def best_strategy_compiled(args,
exec_strategy.num_threads = 1 exec_strategy.num_threads = 1
compiled_program = fluid.CompiledProgram(program).with_data_parallel( compiled_program = fluid.CompiledProgram(program).with_data_parallel(
loss_name=loss.name if mode == "train" else loss, loss_name=loss.name if mode == "train" else None,
share_vars_from=share_prog if mode == "val" else None, share_vars_from=share_prog if mode == "val" else None,
build_strategy=build_strategy, build_strategy=build_strategy,
exec_strategy=exec_strategy) exec_strategy=exec_strategy)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册